Source code for ot.gmm

# -*- coding: utf-8 -*-
"""
Optimal transport for Gaussian Mixtures
"""

# Author: Eloi Tanguy <eloi.tanguy@u-paris>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Julie Delon <julie.delon@math.cnrs.fr>
#
# License: MIT License

from .backend import get_backend
from .lp import emd2, emd
import numpy as np
from .lp import dist
from .gaussian import bures_wasserstein_mapping


def gaussian_pdf(x, m, C):
    r"""
    Evaluate the density of a multivariate Gaussian at the given points.

    Parameters
    ----------
    x : array-like, shape (..., d)
        Points at which the density is evaluated.
    m : array-like, shape (d,)
        Mean vector of the Gaussian.
    C : array-like, shape (d, d)
        Covariance matrix of the Gaussian.

    Returns
    -------
    pdf : array-like, shape (...,)
        Density value for each input point.
    """
    dim = x.shape[-1]
    assert dim == m.shape[-1] == C.shape[-1] == C.shape[-2], "Dimension mismatch"
    nx = get_backend(x, m, C)

    # normalising constant (2 pi)^(-d/2) * det(C)^(-1/2)
    norm_const = (2 * np.pi) ** (-dim / 2) * nx.det(C) ** (-0.5)

    # quadratic form (x - m) C^{-1} (x - m)^T, reduced over the last axis
    centered = x - m
    quad_form = nx.sum((centered @ nx.inv(C)) * centered, axis=-1)

    return norm_const * nx.exp(-0.5 * quad_form)
def gmm_pdf(x, m, C, w):
    r"""
    Evaluate the density of a Gaussian Mixture Model (GMM) at the given
    points, as the weighted sum of the densities of its components.

    Parameters
    ----------
    x : array-like, shape (..., d)
        Points at which the density is evaluated.
    m : array-like, shape (n_components, d)
        Means of the Gaussian components.
    C : array-like, shape (n_components, d, d)
        Covariance matrices of the Gaussian components.
    w : array-like, shape (n_components,)
        Weights of the Gaussian components.

    Returns
    -------
    out : array-like, shape (...,)
        Mixture density at each input point.
    """
    n_components = m.shape[0]
    assert (
        n_components == C.shape[0] == w.shape[0]
    ), "All GMM parameters must have the same amount of components"
    nx = get_backend(x, m, C, w)

    # one density value per sample: drop the trailing feature axis
    density = nx.zeros(x.shape[:-1])
    for idx in range(n_components):
        component_density = gaussian_pdf(x, m[idx], C[idx])
        density = density + w[idx] * component_density

    return density
def dist_bures_squared(m_s, m_t, C_s, C_t):
    r"""
    Build the matrix of squared Bures distances between every pair of
    components of two Gaussian Mixture Models (GMMs). This matrix is the
    ground cost of the GMM Optimal Transport distance [69].

    Entry ``(i, j)`` is the squared Euclidean distance between the means
    ``m_s[i]`` and ``m_t[j]`` plus
    ``trace(C_s[i] + C_t[j] - 2 (C_s[i]^{1/2} C_t[j] C_s[i]^{1/2})^{1/2})``,
    clipped from below at zero.

    Parameters
    ----------
    m_s : array-like, shape (k_s, d)
        Mean vectors of the source GMM.
    m_t : array-like, shape (k_t, d)
        Mean vectors of the target GMM.
    C_s : array-like, shape (k_s, d, d)
        Covariance matrices of the source GMM.
    C_t : array-like, shape (k_t, d, d)
        Covariance matrices of the target GMM.

    Returns
    -------
    dist : array-like, shape (k_s, k_t)
        Matrix of squared Bures distances between the components of the
        source and target GMMs.

    References
    ----------
    .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance
        in the space of Gaussian mixture models. SIAM Journal on Imaging
        Sciences, 13(2), 936-970.
    """
    nx = get_backend(m_s, C_s, m_t, C_t)

    assert m_s.shape[0] == C_s.shape[0], "Source GMM has different amount of components"
    assert m_t.shape[0] == C_t.shape[0], "Target GMM has different amount of components"
    assert (
        m_s.shape[-1] == m_t.shape[-1] == C_s.shape[-1] == C_t.shape[-1]
    ), "All GMMs must have the same dimension"

    # mean part of the cost: pairwise squared Euclidean distances
    means_cost = dist(m_s, m_t, metric="sqeuclidean")

    # sandwiched[i, j] = sqrt_C_s[i] @ C_t[j] @ sqrt_C_s[i], shape (k_s, k_t, d, d)
    sqrt_C_s = nx.sqrtm(C_s)  # matrix sqrt applied to each of the k_s matrices
    sandwiched = nx.einsum("ikl,jlm,imn->ijkn", sqrt_C_s, C_t, sqrt_C_s)
    cross_sqrt = nx.sqrtm(sandwiched)  # matrix sqrt applied per (i, j) pair

    # covariance part: covs_cost[i, j] = trace(C_s[i] + C_t[j] - 2 cross_sqrt[i, j])
    source_traces = nx.einsum("ikk->i", C_s)[:, None]  # (k_s, 1)
    target_traces = nx.einsum("ikk->i", C_t)[None, :]  # (1, k_t)
    covs_cost = source_traces + target_traces  # broadcasts to (k_s, k_t)
    covs_cost -= 2 * nx.einsum("ijkk->ij", cross_sqrt)

    # clip at zero so the returned cost is never negative
    return nx.maximum(means_cost + covs_cost, 0)
def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t, log=False):
    r"""
    Compute the Gaussian Mixture Model (GMM) Optimal Transport distance
    between two GMMs, as introduced in [69].

    The loss is the discrete optimal transport cost between the component
    weights ``w_s`` and ``w_t`` for the ground cost given by
    :func:`dist_bures_squared`.

    Parameters
    ----------
    m_s : array-like, shape (k_s, d)
        Mean vectors of the source GMM.
    m_t : array-like, shape (k_t, d)
        Mean vectors of the target GMM.
    C_s : array-like, shape (k_s, d, d)
        Covariance matrices of the source GMM.
    C_t : array-like, shape (k_t, d, d)
        Covariance matrices of the target GMM.
    w_s : array-like, shape (k_s,)
        Weights of the source GMM components.
    w_t : array-like, shape (k_t,)
        Weights of the target GMM components.
    log : bool, optional (default=False)
        If True, returns a dictionary containing the cost and dual variables.
        Otherwise returns only the GMM optimal transportation cost.

    Returns
    -------
    loss : float or array-like
        The GMM-OT loss.
    log : dict, optional
        If input log is true, a dictionary containing the cost and dual
        variables and exit status.

    References
    ----------
    .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance
        in the space of Gaussian mixture models. SIAM Journal on Imaging
        Sciences, 13(2), 936-970.
    """
    # Return value is discarded: the call is made only for whatever checks
    # get_backend performs on its arguments (presumably that they all share
    # a backend -- confirm against ot.backend).
    get_backend(m_s, C_s, w_s, m_t, C_t, w_t)

    assert m_s.shape[0] == w_s.shape[0], "Source GMM has different amount of components"
    assert m_t.shape[0] == w_t.shape[0], "Target GMM has different amount of components"

    # ground cost between components, then exact OT on the mixture weights
    component_costs = dist_bures_squared(m_s, m_t, C_s, C_t)
    return emd2(w_s, w_t, component_costs, log=log)
def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t, log=False):
    r"""
    Compute the Gaussian Mixture Model (GMM) Optimal Transport plan between
    two GMMs, as introduced in [69].

    The plan is the discrete optimal transport plan between the component
    weights ``w_s`` and ``w_t`` for the ground cost given by
    :func:`dist_bures_squared`.

    Parameters
    ----------
    m_s : array-like, shape (k_s, d)
        Mean vectors of the source GMM.
    m_t : array-like, shape (k_t, d)
        Mean vectors of the target GMM.
    C_s : array-like, shape (k_s, d, d)
        Covariance matrices of the source GMM.
    C_t : array-like, shape (k_t, d, d)
        Covariance matrices of the target GMM.
    w_s : array-like, shape (k_s,)
        Weights of the source GMM components.
    w_t : array-like, shape (k_t,)
        Weights of the target GMM components.
    log : bool, optional (default=False)
        If True, returns a dictionary containing the cost and dual variables.
        Otherwise returns only the GMM optimal transportation matrix.

    Returns
    -------
    plan : array-like, shape (k_s, k_t)
        The GMM-OT plan.
    log : dict, optional
        If input log is true, a dictionary containing the cost and dual
        variables and exit status.

    References
    ----------
    .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance
        in the space of Gaussian mixture models. SIAM Journal on Imaging
        Sciences, 13(2), 936-970.
    """
    # Return value is discarded: the call is made only for whatever checks
    # get_backend performs on its arguments (presumably that they all share
    # a backend -- confirm against ot.backend).
    get_backend(m_s, C_s, w_s, m_t, C_t, w_t)

    assert m_s.shape[0] == w_s.shape[0], "Source GMM has different amount of components"
    assert m_t.shape[0] == w_t.shape[0], "Target GMM has different amount of components"

    # ground cost between components, then exact OT on the mixture weights
    component_costs = dist_bures_squared(m_s, m_t, C_s, C_t)
    return emd(w_s, w_t, component_costs, log=log)
def gmm_ot_apply_map(
    x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, method="bary", seed=None
):
    r"""
    Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input
    data. The 'barycentric' mapping corresponds to the barycentric projection
    of the GMM-OT plan, and is called T_bary in [69]. The 'random' mapping
    takes for each input point a random pair (i,j) of components of the GMMs
    and applies the Gaussian map, it is called T_rand in [69].

    Parameters
    ----------
    x : array-like, shape (n_samples, d)
        Input data points.
    m_s : array-like, shape (k_s, d)
        Mean vectors of the source GMM components.
    m_t : array-like, shape (k_t, d)
        Mean vectors of the target GMM components.
    C_s : array-like, shape (k_s, d, d)
        Covariance matrices of the source GMM components.
    C_t : array-like, shape (k_t, d, d)
        Covariance matrices of the target GMM components.
    w_s : array-like, shape (k_s,)
        Weights of the source GMM components.
    w_t : array-like, shape (k_t,)
        Weights of the target GMM components.
    plan : array-like, shape (k_s, k_t), optional
        Optimal transport plan between the source and target GMM components.
        If not provided, it will be computed internally.
    method : {'bary', 'rand'}, optional
        Method for applying the GMM OT mapping. 'bary' uses barycentric
        mapping, while 'rand' uses random sampling. Default is 'bary'.
        Any value other than 'bary' selects the random mapping.
    seed : int, optional
        Seed for the random number generator. Only used when method='rand'.

    Returns
    -------
    out : array-like, shape (n_samples, d)
        Output data points after applying the GMM OT mapping.

    References
    ----------
    .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance
        in the space of Gaussian mixture models. SIAM Journal on Imaging
        Sciences, 13(2), 936-970.
    """
    if plan is None:
        plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t)
        nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t)
    else:
        # a user-supplied plan takes part in the backend resolution as well
        nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t, plan)

    k_s, k_t = m_s.shape[0], m_t.shape[0]
    d = m_s.shape[1]
    n_samples = x.shape[0]

    if method == "bary":
        # GMM density at each sample, as a column for row-wise division
        normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None]
        out = nx.zeros(x.shape)

        # only need to compute for non-zero plan entries
        for i, j in zip(*nx.where(plan > 0)):
            Cs12 = nx.sqrtm(C_s[i])
            Cs12inv = nx.inv(Cs12)
            g = gaussian_pdf(x, m_s[i], C_s[i])[:, None]

            M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12)
            A = Cs12inv @ M0 @ Cs12inv
            b = m_t[j] - A @ m_s[i]

            # gaussian mapping between components i and j applied to x
            T_ij_x = x @ A + b
            out = out + plan[i, j] * g * T_ij_x

        return out / normalization

    # rand
    # A[i, j] is the linear part of the gaussian mapping between components
    # i and j, b[i, j] is the translation part
    rng = np.random.RandomState(seed)

    A = nx.zeros((k_s, k_t, d, d))
    b = nx.zeros((k_s, k_t, d))

    # only need to compute for non-zero plan entries
    for i, j in zip(*nx.where(plan > 0)):
        Cs12 = nx.sqrtm(C_s[i])
        Cs12inv = nx.inv(Cs12)

        M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12)
        A[i, j] = Cs12inv @ M0 @ Cs12inv
        b[i, j] = m_t[j] - A[i, j] @ m_s[i]

    normalization = gmm_pdf(x, m_s, C_s, w_s)  # (n_samples,)
    # NOTE(review): np.stack and the item assignments into A, b and out are
    # NumPy-style operations rather than backend calls -- confirm this branch
    # behaves as intended for non-NumPy backends.
    gs = np.stack([gaussian_pdf(x, m_s[i], C_s[i]) for i in range(k_s)], axis=-1)
    # (n_samples, k_s)
    out = nx.zeros(x.shape)

    for i_sample in range(n_samples):
        # probability of each component pair (i, j) for this sample
        p_mat = plan * gs[i_sample][:, None] / normalization[i_sample]
        p = p_mat.reshape(k_s * k_t)  # stack line-by-line
        # sample between 0 and k_s * k_t - 1
        ij_mat = rng.choice(k_s * k_t, p=p)
        # recover the (row, column) pair from the flattened index
        i = ij_mat // k_t
        j = ij_mat % k_t
        out[i_sample] = A[i, j] @ x[i_sample] + b[i, j]

    return out
def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-2):
    r"""
    Compute the density of the Gaussian Mixture Model - Optimal Transport
    coupling between GMMs at given points, as introduced in [69].
    Given two arrays of points x and y, the function computes the density at
    every pair `(x[i], y[j])` of the product space.

    Parameters
    ----------
    x : array-like, shape (n, d)
        Entry points in source space for density computation.
    y : array-like, shape (m, d)
        Entry points in target space for density computation.
    m_s : array-like, shape (k_s, d)
        The means of the source GMM components.
    m_t : array-like, shape (k_t, d)
        The means of the target GMM components.
    C_s : array-like, shape (k_s, d, d)
        The covariance matrices of the source GMM components.
    C_t : array-like, shape (k_t, d, d)
        The covariance matrices of the target GMM components.
    w_s : array-like, shape (k_s,)
        The weights of the source GMM components.
    w_t : array-like, shape (k_t,)
        The weights of the target GMM components.
    plan : array-like, shape (k_s, k_t), optional
        The optimal transport plan between the source and target GMMs.
        If not provided, it will be computed using `gmm_ot_plan`.
    atol : float, optional
        The absolute tolerance used to determine the support of the GMM-OT
        coupling.

    Returns
    -------
    density : array-like, shape (n, m)
        The density of the GMM-OT coupling between the two GMMs.

    References
    ----------
    .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance
        in the space of Gaussian mixture models. SIAM Journal on Imaging
        Sciences, 13(2), 936-970.
    """
    assert (
        x.shape[-1] == y.shape[-1]
    ), "x (n, d) and y (m, d) must have the same dimension d"
    n_x, n_y = x.shape[0], y.shape[0]
    nx = get_backend(x, y, m_s, m_t, C_s, C_t, w_s, w_t)

    # hand-made d-variate meshgrid in ij indexing, both of shape (n, m, d)
    x_grid = x[:, None, :] * nx.ones((1, n_y, 1))
    y_grid = y[None, :, :] * nx.ones((n_x, 1, 1))

    if plan is None:
        plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t)

    def pair_density(i, j):
        # contribution of the component pair (i, j) on the whole grid
        A, b = bures_wasserstein_mapping(m_s[i], m_t[j], C_s[i], C_t[j])
        mapped = x_grid @ A + b
        source_density = gaussian_pdf(x_grid, m_s[i], C_s[i])
        weighted = plan[i, j] * source_density
        # keep only grid points whose mapped x lies within atol of y
        gaps = nx.norm(mapped - y_grid, axis=-1)
        return weighted * ((gaps < atol) * 1.0)

    # stacked contributions; the two leading axes index the component pairs
    per_source = []
    for i in range(m_s.shape[0]):
        per_target = [pair_density(i, j) for j in range(m_t.shape[0])]
        per_source.append(nx.stack(per_target))
    contributions = nx.stack(per_source)

    return nx.sum(contributions, axis=(0, 1))