
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/barycenters/plot_gmm_barycenter.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_barycenters_plot_gmm_barycenter.py:


=====================================
Gaussian Mixture Model OT Barycenters
=====================================

This example illustrates the computation of a barycenter between Gaussian
mixtures in the sense of GMM-OT [69]. The barycenter is computed with the
fixed-point method for OT barycenters with generic costs [77], for which POT
provides both a general solver and a GMM-specific solver. This is a
'free-support' method: the means and covariances of the barycenter's
components are optimized, while the number of components and their weights
stay fixed.

The idea behind GMM-OT barycenters is to view GMMs as discrete measures over
the space of Gaussian distributions :math:`\mathcal{N}` (equivalently, the
Bures-Wasserstein manifold) and to compute barycenters with respect to the
2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`.
A Gaussian mixture is a finite combination of Diracs located at specific
Gaussians, and two mixtures are compared with the 2-Wasserstein distance on
this space, whose ground cost is the squared Bures-Wasserstein distance between
Gaussians.

[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the
space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2),
936-970.

[77] Tanguy, E., Delon, J., & Gozlan, N. (2024). Computing Barycentres of
Measures for Generic Transport Costs. arXiv preprint arXiv:2501.04016.

.. GENERATED FROM PYTHON SOURCE LINES 30-37

.. code-block:: Python


    # Author: Eloi Tanguy
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 1

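
The ground cost described above has a closed form. For two Gaussians,

.. math::

    W_2^2\big(\mathcal{N}(m_0, \Sigma_0), \mathcal{N}(m_1, \Sigma_1)\big)
    = \|m_0 - m_1\|^2
    + \mathrm{tr}\Big(\Sigma_0 + \Sigma_1
    - 2\big(\Sigma_0^{1/2} \Sigma_1 \Sigma_0^{1/2}\big)^{1/2}\Big).

The following NumPy-only sketch evaluates this cost and the resulting GMM-OT
distance. It is not POT's implementation, and the helper names
(``sqrtm_spd``, ``bures_w2_squared``, ``gmm_cost_matrix``, ``gmm_ot_uniform``)
are hypothetical. To avoid needing an LP solver, it assumes that both mixtures
have the same number of components with uniform weights. In that case an
optimal transport plan is a permutation matrix scaled by :math:`1/n`, so an
exhaustive search over permutations is exact, though only practical for a
handful of components.

.. code-block:: Python

    import itertools

    import numpy as np


    def sqrtm_spd(C):
        """Square root of a symmetric positive semi-definite matrix."""
        vals, vecs = np.linalg.eigh(C)
        vals = np.clip(vals, 0.0, None)  # clip tiny negative round-off
        return (vecs * np.sqrt(vals)) @ vecs.T


    def bures_w2_squared(m0, C0, m1, C1):
        """Squared 2-Wasserstein distance between N(m0, C0) and N(m1, C1)."""
        r0 = sqrtm_spd(C0)
        inner = r0 @ C1 @ r0
        inner = (inner + inner.T) / 2  # symmetrize before the eigendecomposition
        cross = sqrtm_spd(inner)
        return float(np.sum((m0 - m1) ** 2) + np.trace(C0 + C1 - 2.0 * cross))


    def gmm_cost_matrix(means0, covs0, means1, covs1):
        """M[i, j] = squared W2 distance between component i of the first
        mixture and component j of the second (the GMM-OT ground cost)."""
        return np.array(
            [
                [bures_w2_squared(m0, C0, m1, C1) for m1, C1 in zip(means1, covs1)]
                for m0, C0 in zip(means0, covs0)
            ]
        )


    def gmm_ot_uniform(means0, covs0, means1, covs1):
        """Squared GMM-OT distance between two mixtures with the same number
        of components and uniform weights (exhaustive search, small n only)."""
        M = gmm_cost_matrix(means0, covs0, means1, covs1)
        n = M.shape[0]
        best = min(
            M[np.arange(n), list(perm)].sum()
            for perm in itertools.permutations(range(n))
        )
        return float(best) / n


    # Reordering the components of a mixture leaves it unchanged, so its
    # GMM-OT distance to the original is close to 0.0; translating every
    # component by t = (1, 2) gives a distance close to ||t||^2 = 5.0.
    rng = np.random.default_rng(0)
    means = rng.normal(size=(3, 2))
    covs = np.array([s * np.eye(2) for s in (0.5, 1.0, 2.0)])
    perm = [2, 0, 1]
    print(gmm_ot_uniform(means, covs, means[perm], covs[perm]))
    print(gmm_ot_uniform(means, covs, means + np.array([1.0, 2.0]), covs))

In one dimension the formula reduces to
:math:`(m_0 - m_1)^2 + (\sigma_0 - \sigma_1)^2`, which is a convenient check
for any implementation of the Bures-Wasserstein cost.
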

.. GENERATED FROM PYTHON SOURCE LINES 38-39

Generate data

.. GENERATED FROM PYTHON SOURCE LINES 39-79

.. code-block:: Python


    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Ellipse
    import ot
    from ot.gmm import gmm_barycenter_fixed_point


    K = 3  # number of GMMs
    d = 2  # dimension
    n = 6  # number of components of the desired barycenter


    def get_random_gmm(n_components, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
        rng = np.random.RandomState(seed=seed)
        means = rng.randn(n_components, d)
        P = rng.randn(n_components, d, d) * cov_scale
        # C[k] = P[k] @ P[k]^T + min_cov_eig * I (positive definite for min_cov_eig > 0)
        covariances = np.einsum("kab,kcb->kac", P, P)
        covariances += min_cov_eig * np.eye(d)  # broadcast over all components
        weights = rng.random(n_components)
        weights /= np.sum(weights)
        return means, covariances, weights


    m_list = [5, 6, 7]  # number of components in each GMM
    offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]
    means_list = []  # list of means for each GMM
    covs_list = []  # list of covariances for each GMM
    w_list = []  # list of weights for each GMM

    # generate GMMs
    for k in range(K):
        means, covs, b = get_random_gmm(
            m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5
        )
        means = means / 2 + offsets[k][None, :]
        means_list.append(means)
        covs_list.append(covs)
        w_list.append(b)

.. GENERATED FROM PYTHON SOURCE LINES 80-81

Compute the barycenter using the fixed-point method

.. GENERATED FROM PYTHON SOURCE LINES 81-95

.. code-block:: Python


    init_means, init_covs, _ = get_random_gmm(n, d, seed=0)  # random initial support
    weights = ot.unif(K)  # barycenter coefficients of the K input GMMs
    means_bar, covs_bar, log = gmm_barycenter_fixed_point(
        means_list,
        covs_list,
        w_list,
        init_means,
        init_covs,
        weights,
        iterations=3,
        log=True,
    )

.. GENERATED FROM PYTHON SOURCE LINES 96-97

Define plotting functions

.. GENERATED FROM PYTHON SOURCE LINES 97-133

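
Each component is drawn as an ellipse obtained from the eigendecomposition of
its covariance: the axes point along the eigenvectors, the full axis lengths
are ``2 * nstd * sqrt(eigenvalue)``, and the rotation angle is that of the
eigenvector with the largest eigenvalue. With the default ``nstd=0.5`` of
``draw_gmm``, each ellipse therefore extends half a standard deviation from the
mean along each principal axis. The sketch below isolates that computation so
it can be checked without Matplotlib; the helper name ``cov_ellipse_params`` is
hypothetical. Because eigenvectors are only defined up to sign, the angle is
meaningful only modulo 180 degrees, which does not matter for a centrally
symmetric ellipse.

.. code-block:: Python

    import numpy as np


    def cov_ellipse_params(C, nstd=1.0):
        """Return (width, height, angle in degrees) of the nstd-sigma ellipse
        of a 2x2 covariance, following the same steps as draw_cov."""
        vals, vecs = np.linalg.eigh(C)
        order = vals.argsort()[::-1]  # largest eigenvalue first
        vals, vecs = vals[order], vecs[:, order]
        x, y = vecs[:, 0]  # leading eigenvector
        angle = np.degrees(np.arctan2(y, x))
        width, height = 2 * nstd * np.sqrt(vals)
        return width, height, angle


    # Covariance with standard deviation 2 along the 45-degree direction
    # and 1 across it.
    c, s = np.cos(np.pi / 4), np.sin(np.pi / 4)
    R = np.array([[c, -s], [s, c]])
    C = R @ np.diag([4.0, 1.0]) @ R.T
    width, height, angle = cov_ellipse_params(C)
    print(width, height, angle % 180)  # width 4, height 2, angle 45 (up to round-off)
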

.. code-block:: Python


    # draw the nstd-sigma ellipse of the 2D Gaussian N(mu, C)
    def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
        def eigsorted(cov):
            vals, vecs = np.linalg.eigh(cov)
            order = vals.argsort()[::-1].copy()
            return vals[order], vecs[:, order]

        vals, vecs = eigsorted(C)
        theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
        w, h = 2 * nstd * np.sqrt(vals)
        ell = Ellipse(
            xy=(mu[0], mu[1]),
            width=w,
            height=h,
            alpha=alpha,
            angle=theta,
            facecolor=color,
            edgecolor=color,
            label=label,
            fill=True,
        )
        if ax is None:
            ax = plt.gca()
        ax.add_artist(ell)


    # draw a GMM as a set of ellipses, with each component's weight shown as its opacity
    def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
        for k in range(ms.shape[0]):
            draw_cov(
                ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax
            )

.. GENERATED FROM PYTHON SOURCE LINES 134-135

Plot the results

.. GENERATED FROM PYTHON SOURCE LINES 135-146

.. code-block:: Python


    c_list = ["#7ED321", "#4A90E2", "#9013FE", "#F5A623"]
    c_bar = "#D0021B"
    fig, ax = plt.subplots(figsize=(6, 6))
    axis = [-4, 4, -2, 6]
    ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
    for k in range(K):  # input GMMs
        draw_gmm(means_list[k], covs_list[k], w_list[k], color=c_list[k], ax=ax)
    draw_gmm(means_bar, covs_bar, ot.unif(n), color=c_bar, ax=ax)  # barycenter, uniform weights
    ax.axis(axis)
    ax.axis("off")



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_gmm_barycenter_001.png
   :alt: Fixed Point Barycenter (3 Iterations)
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_gmm_barycenter_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (np.float64(-4.0), np.float64(4.0), np.float64(-2.0), np.float64(6.0))




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.096 seconds)

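
In the special case where every input GMM and the barycenter consist of a
single Gaussian, the GMM-OT objective reduces to a weighted sum of squared
Bures-Wasserstein distances, and the barycenter is the 2-Wasserstein barycenter
of the input Gaussians. That barycenter admits a classical fixed-point
iteration of its own, which gives a small, self-contained illustration of the
fixed-point idea. Its mean is the weighted average of the input means, and its
covariance :math:`\Sigma` solves

.. math::

    \Sigma = \sum_k \lambda_k \big(\Sigma^{1/2} \Sigma_k \Sigma^{1/2}\big)^{1/2},

which can be approached by iterating

.. math::

    \Sigma \leftarrow \Sigma^{-1/2}
    \Big(\sum_k \lambda_k \big(\Sigma^{1/2} \Sigma_k \Sigma^{1/2}\big)^{1/2}\Big)^2
    \Sigma^{-1/2}.

The NumPy-only sketch below implements that iteration for illustration. It is
a separate, standard method for Gaussians and is not presented as the
implementation behind POT's ``gmm_barycenter_fixed_point``; the function name
``gaussian_barycenter`` and its ``n_iter`` and ``tol`` parameters are
hypothetical.

.. code-block:: Python

    import numpy as np


    def sqrtm_spd(C):
        """Square root of a symmetric positive semi-definite matrix."""
        vals, vecs = np.linalg.eigh(C)
        vals = np.clip(vals, 0.0, None)  # clip tiny negative round-off
        return (vecs * np.sqrt(vals)) @ vecs.T


    def gaussian_barycenter(means, covs, lambdas, n_iter=100, tol=1e-10):
        """W2 barycenter of the Gaussians N(means[k], covs[k]) with weights
        lambdas, using the fixed-point map on the covariance."""
        lambdas = np.asarray(lambdas, dtype=float)
        mean = lambdas @ means  # weighted average of the means
        S = np.einsum("k,kab->ab", lambdas, covs)  # start from the averaged covariance
        for _ in range(n_iter):
            r = sqrtm_spd(S)
            r_inv = np.linalg.inv(r)
            T = sum(lam * sqrtm_spd(r @ C @ r) for lam, C in zip(lambdas, covs))
            S_new = r_inv @ T @ T @ r_inv
            S_new = (S_new + S_new.T) / 2  # remove round-off asymmetry
            converged = np.linalg.norm(S_new - S) < tol
            S = S_new
            if converged:
                break
        return mean, S


    # For commuting (here diagonal) covariances the barycenter covariance is
    # (sum_k lambda_k Sigma_k^{1/2})^2, i.e. diag(4, 9) for these inputs.
    means = np.array([[0.0, 0.0], [2.0, 4.0]])
    covs = np.array([np.diag([1.0, 4.0]), np.diag([9.0, 16.0])])
    mean_bar, cov_bar = gaussian_barycenter(means, covs, [0.5, 0.5])
    print(mean_bar)  # [1. 2.]
    print(cov_bar)

At a fixed point, :math:`\Sigma^{1/2}\Sigma\Sigma^{1/2} = T^2` with
:math:`T = \sum_k \lambda_k (\Sigma^{1/2}\Sigma_k\Sigma^{1/2})^{1/2}`, so
:math:`\Sigma = T`, which is exactly the barycenter equation above.
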