Note
Go to the end to download the full example code
Gaussian Bures-Wasserstein barycenters
Illustration of Gaussian Bures-Wasserstein barycenters.
# Authors: Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
Define Gaussian Covariances and distributions
Plot the distributions
def draw_cov(mu, C, color=None, label=None, nstd=1):
    """Add a filled, semi-transparent covariance ellipse to the current axes.

    Parameters
    ----------
    mu : indexable
        Centre of the ellipse; ``mu[0]`` and ``mu[1]`` are used as x and y.
    C : array-like
        Matrix handed to ``np.linalg.eigh``; its two eigenvalues set the
        ellipse extents and its leading eigenvector sets the orientation.
    color : optional
        Used for both the face and the edge of the ellipse.
    label : optional
        Label forwarded to the ellipse patch.
    nstd : scalar, optional
        Scale factor: the full width and height are
        ``2 * nstd * sqrt(eigenvalue)``.
    """
    # Order the eigenpairs by decreasing eigenvalue so that column 0 of the
    # eigenvector matrix is the major axis.
    eigvals, eigvecs = np.linalg.eigh(C)
    descending = np.argsort(eigvals)[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]

    # Rotation (in degrees) of the major axis relative to the x axis.
    lead_x, lead_y = eigvecs[:, 0]
    angle_deg = np.degrees(np.arctan2(lead_y, lead_x))

    width, height = 2 * nstd * np.sqrt(eigvals)

    patch = Ellipse(
        xy=(mu[0], mu[1]),
        width=width,
        height=height,
        angle=angle_deg,
        alpha=0.5,
        facecolor=color,
        edgecolor=color,
        label=label,
        fill=True,
    )
    pl.gca().add_artist(patch)
# Common axis limits [xmin, xmax, ymin, ymax] applied to every panel so the
# ellipses can be compared on the same scale.
axis = [-1.5, 5.5, -1.5, 5.5]

pl.figure(1, (8, 2))
pl.clf()

# One panel per input Gaussian (means/covariances are defined earlier in the
# example), each drawn with its own default-cycle color 'C0'..'C3'.
gaussians = ((m1, C1), (m2, C2), (m3, C3), (m4, C4))
for k, (m_k, C_k) in enumerate(gaussians, start=1):
    pl.subplot(1, 4, k)
    draw_cov(m_k, C_k, color=f'C{k - 1}')
    pl.axis(axis)
    # Raw string: '\m' and '\S' are invalid escape sequences in a regular
    # literal. Doubled braces are literal braces in the f-string.
    pl.title(rf'$\mathcal{{N}}(m_{k},\Sigma_{k})$')
Text(0.5, 1.0, '$\\mathcal{N}(m_4,\\Sigma_4)$')
Compute Bures-Wasserstein barycenters and plot them
# basis for bilinear interpolation: one one-hot weight vector per input Gaussian
v1 = np.array((1, 0, 0, 0))
v2 = np.array((0, 1, 0, 0))
v3 = np.array((0, 0, 1, 0))
v4 = np.array((0, 0, 0, 1))

# RGB triplets of 'C0'..'C3', one row per Gaussian. Kept under its own name so
# the `colors` module used here is not overwritten (rebinding it would make
# this section fail when run a second time).
corner_colors = np.stack([colors.to_rgb(f'C{k}') for k in range(4)])

pl.figure(2, (8, 8))

nb_interp = 6

for i in range(nb_interp):
    for j in range(nb_interp):
        # Position in the unit square, each coordinate in [0, 1].
        tx = i / (nb_interp - 1)
        ty = j / (nb_interp - 1)

        # weights are constructed by bilinear interpolation; they are
        # non-negative and sum to one by construction
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        # Blend the four corner colors with the same weights.
        color = np.dot(corner_colors.T, weights)

        # `m` and `C` are defined earlier in the example (presumably the
        # stacked means and covariances of the four Gaussians).
        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)

        draw_cov(mb, Cb, color=color, label=None, nstd=0.3)

pl.axis(axis)
pl.axis('off')
pl.tight_layout()
Total running time of the script: (0 minutes 0.332 seconds)