GMM Plan 1D

Illustration of the GMM plan for the Mixture Wasserstein distance between two GMMs in 1D, as well as the two maps T_mean and T_rand. T_mean is the barycentric projection of the GMM coupling, and T_rand takes a random Gaussian image between two components, chosen according to the coupling and the GMMs. See [69] for details.

.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.

# Author: Eloi Tanguy <eloi.tanguy@u-paris>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#         Julie Delon <julie.delon@math.cnrs.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
from ot.plot import plot1D_mat, rescale_for_imshow_plot
from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map
import matplotlib.pyplot as plt

Generate the GMM-OT plan and plot it

# Two 1D Gaussian mixtures: a 2-component source and a 3-component target.
# Means are stacked as (k, 1) arrays, covariances as (k, 1, 1), weights as (k,).
# Float literals keep the source means in the same dtype as the target means.
m_s = np.array([[1.0], [2.0]])
m_t = np.array([[3.0], [4.2], [5.0]])
C_s = np.array([[[0.05]], [[0.06]]])
C_t = np.array([[[0.03]], [[0.07]], [[0.04]]])
w_s = np.array([0.4, 0.6])
w_t = np.array([0.4, 0.2, 0.4])

# Evaluation grids: x spans the source means, y spans the target means.
n = 500
a_x, b_x = 0, 3
x = np.linspace(a_x, b_x, n)
a_y, b_y = 2, 6
y = np.linspace(a_y, b_y, n)

# Density of the GMM-OT plan evaluated on the (x, y) grid. The grids are
# passed as single-column arrays, matching the (k, 1) layout of the means.
plan_density = gmm_ot_plan_density(
    x[:, None], y[:, None], m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=2e-2
)

# Densities of the source and target mixtures on their respective grids,
# shown as the marginals alongside the plan.
a = gmm_pdf(x[:, None], m_s, C_s, w_s)
b = gmm_pdf(y[:, None], m_t, C_t, w_t)
plt.figure(figsize=(8, 8))
plot1D_mat(
    a,
    b,
    plan_density,
    title="GMM OT plan",
    plot_style="xy",
    a_label="Source distribution",
    b_label="Target distribution",
)
Source distribution, Target distribution, GMM OT plan
(<Axes: title={'center': 'Source distribution'}>, <Axes: title={'center': 'Target distribution'}>, <Axes: title={'center': 'GMM OT plan'}>)

Generate the GMM-OT maps and plot them over the plan

# Redraw the same plan, this time keeping the axes handles so that both maps
# can be overlaid on the joint-density panel (ax_M).
plt.figure(figsize=(8, 8))
ax_s, ax_t, ax_M = plot1D_mat(
    a,
    b,
    plan_density,
    plot_style="xy",
    title="GMM OT plan with T_mean and T_rand maps",
    a_label="Source distribution",
    b_label="Target distribution",
)

# Mixture parameters shared by both map evaluations, and the source grid as a
# single-column array.
gmm_params = (m_s, m_t, C_s, C_t, w_s, w_t)
x_col = x[:, None]

# T_mean: barycentric projection of the GMM coupling (method "bary").
T_mean = gmm_ot_apply_map(x_col, *gmm_params, method="bary")[:, 0]
# T_rand: randomized map (method "rand"); the explicit seed fixes the draw.
T_rand = gmm_ot_apply_map(x_col, *gmm_params, method="rand", seed=0)[:, 0]

# Each map is rescaled with its own call so its x coordinates stay paired
# with its own values, then drawn on the plan panel: T_mean as a thick line,
# T_rand as scattered points.
x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, a_y=a_y, b_y=b_y)
ax_M.plot(
    x_rescaled, T_mean_rescaled, label="T_mean", alpha=0.5, linewidth=5, color="aqua"
)

x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n, a_y=a_y, b_y=b_y)
ax_M.scatter(
    x_rescaled, T_rand_rescaled, label="T_rand", alpha=0.5, s=20, color="orange"
)

ax_M.legend(loc="upper left", fontsize=13)
Source distribution, Target distribution, GMM OT plan with T_mean and T_rand maps
<matplotlib.legend.Legend object at 0x7ff964c49ab0>

Total running time of the script: (0 minutes 0.375 seconds)

Gallery generated by Sphinx-Gallery