.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_EWCA.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_others_plot_EWCA.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_EWCA.py:


=======================================
Entropic Wasserstein Component Analysis
=======================================

This example illustrates Entropic Wasserstein Component Analysis (EWCA), as
proposed in [52], on a toy 2D dataset made of two Gaussian blobs.

[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). Entropic
Wasserstein Component Analysis.

.. GENERATED FROM PYTHON SOURCE LINES 14-28

.. code-block:: Python


    # Author: Antoine Collas
    #
    # License: MIT License
    # sphinx_gallery_thumbnail_number = 2

    import numpy as np
    import matplotlib.pylab as pl
    from ot.dr import ewca
    from sklearn.datasets import make_blobs
    from matplotlib import ticker as mticker
    import matplotlib.patches as patches
    import matplotlib

.. GENERATED FROM PYTHON SOURCE LINES 29-31

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 31-48

.. code-block:: Python


    # Two isotropic Gaussian blobs of n_samples // 2 points each, centered at
    # (esp, esp) and (-esp, -esp). With shuffle=False, the samples of the first
    # blob fill the first half of X.
    n_samples = 20
    esp = 0.8
    centers = np.array([[esp, esp], [-esp, -esp]])
    cluster_std = 0.4

    rng = np.random.RandomState(42)
    X, y = make_blobs(
        n_samples=n_samples,
        n_features=2,
        centers=centers,
        cluster_std=cluster_std,
        shuffle=False,
        random_state=rng,
    )
    # Center the data.
    X = X - X.mean(0)

.. GENERATED FROM PYTHON SOURCE LINES 49-51

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 51-83

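
The two blob centers, ``(esp, esp)`` and ``(-esp, -esp)``, lie on the
diagonal, so most of the variance of the centered data should lie along it. As
a reference point for the EWCA component computed later (this comparison is not
part of the original example), the following sketch computes the classical PCA
direction with an SVD. It is a hypothetical NumPy-only stand-in: it draws its
own blobs instead of calling ``make_blobs``, so its points differ from ``X``.

.. code-block:: Python

    import numpy as np

    # Hypothetical stand-in for the example's data: two isotropic Gaussian
    # blobs centered at +/-(0.8, 0.8) with std 0.4, drawn with NumPy instead of
    # make_blobs, so the points differ from those of the example.
    toy_rng = np.random.RandomState(42)
    toy_centers = np.array([[0.8, 0.8], [-0.8, -0.8]])
    X_toy = np.vstack([c + 0.4 * toy_rng.randn(10, 2) for c in toy_centers])
    X_toy = X_toy - X_toy.mean(0)  # center the data, as in the example

    # PCA: the rows of Vt are the principal axes of the centered data,
    # sorted by decreasing singular value.
    _, _, Vt = np.linalg.svd(X_toy, full_matrices=False)
    u_pca = Vt[0]  # first principal direction (unit norm)

    # |cosine| between u_pca and the diagonal (1, 1) / sqrt(2) joining the centers.
    cos_diag = abs(u_pca @ np.array([1.0, 1.0]) / np.sqrt(2))
    print(f"first principal direction: {u_pca}, |cos| with diagonal: {cos_diag:.3f}")

The sign of ``u_pca`` is arbitrary, which is why the comparison uses the
absolute value of the cosine.
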

.. code-block:: Python


    fig = pl.figure(figsize=(4, 4))

    # Class labels 0 and 1 are mapped to colors 1 and 2 of the "tab10" colormap.
    cmap = matplotlib.colormaps.get_cmap("tab10")

    pl.scatter(
        X[: n_samples // 2, 0],
        X[: n_samples // 2, 1],
        color=[cmap(y[i] + 1) for i in range(n_samples // 2)],
        alpha=0.4,
        label="Class 1",
        zorder=30,
        s=50,
    )
    pl.scatter(
        X[n_samples // 2:, 0],
        X[n_samples // 2:, 1],
        color=[cmap(y[i] + 1) for i in range(n_samples // 2, n_samples)],
        alpha=0.4,
        label="Class 2",
        zorder=30,
        s=50,
    )

    x_y_lim = 2.5
    fs = 15
    pl.xlim(-x_y_lim, x_y_lim)
    pl.xticks([])
    pl.ylim(-x_y_lim, x_y_lim)
    pl.yticks([])
    pl.legend(fontsize=fs)
    pl.title("Data", fontsize=fs)
    pl.tight_layout()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_EWCA_001.png
   :alt: Data
   :srcset: /auto_examples/others/images/sphx_glr_plot_EWCA_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 84-86

Compute EWCA
------------

.. GENERATED FROM PYTHON SOURCE LINES 86-90

.. code-block:: Python


    # reg is the entropic regularization strength. pi is the (n_samples,
    # n_samples) transport plan between the samples and their projections, and
    # U is a (2, k) matrix whose orthonormal columns are the components.
    pi, U = ewca(X, k=2, reg=0.5)

.. GENERATED FROM PYTHON SOURCE LINES 91-93

Plot data, first component, and projected data
----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 93-151

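
The code below maps each sample to ``X1 = X @ u[:, None] @ u[:, None].T``,
the row-vector form of the orthogonal projection ``u u^T x_i`` onto the first
component ``u = U[:, 0]``, assuming ``u`` has unit norm. The following minimal
sketch checks the two properties that make it an orthogonal projection, using
hypothetical stand-in data and a random unit vector in place of the EWCA
component: applying the projector twice changes nothing, and each residual
``x_i - u u^T x_i`` is orthogonal to ``u``.

.. code-block:: Python

    import numpy as np

    # Hypothetical stand-ins: random data and a random unit vector replace
    # X and U[:, 0] from the example.
    demo_rng = np.random.RandomState(0)
    X_demo = demo_rng.randn(20, 2)
    u_demo = demo_rng.randn(2)
    u_demo /= np.linalg.norm(u_demo)

    P_demo = np.outer(u_demo, u_demo)  # symmetric projector u u^T
    X1_demo = X_demo @ P_demo  # row i is u u^T x_i

    residual = X_demo - X1_demo
    print(np.allclose(X1_demo @ P_demo, X1_demo))  # True: projecting twice = once
    print(np.allclose(residual @ u_demo, 0.0))  # True: residuals orthogonal to u

Because ``P_demo`` is symmetric, ``X_demo @ P_demo`` gives the same result as
the example's ``X @ u[:, None] @ u[:, None].T``.
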

.. code-block:: Python


    fig = pl.figure(figsize=(4, 4))

    # Draw the line spanned by the first component u = U[:, 0].
    scale = 3
    u = U[:, 0]
    pl.plot(
        [scale * u[0], -scale * u[0]],
        [scale * u[1], -scale * u[1]],
        color="grey",
        linestyle="--",
        lw=3,
        alpha=0.3,
        label=r"$\mathbf{U}$",
    )

    # Orthogonal projection of every sample onto u (row form of u u^T x_i).
    X1 = X @ u[:, None] @ u[:, None].T

    # Link x_i to the projected sample j for every entry of the plan that is at
    # least 15% of its maximum. The last pair is always drawn so that the
    # legend gets a single pi_ij entry.
    for i in range(n_samples):
        for j in range(n_samples):
            v = pi[i, j] / pi.max()
            if v >= 0.15 or (i, j) == (n_samples - 1, n_samples - 1):
                pl.plot(
                    [X[i, 0], X1[j, 0]],
                    [X[i, 1], X1[j, 1]],
                    alpha=v,
                    linestyle="-",
                    c="C0",
                    label=r"$\pi_{ij}$"
                    if (i, j) == (n_samples - 1, n_samples - 1)
                    else None,
                )

    pl.scatter(
        X[:, 0],
        X[:, 1],
        color=[cmap(y[i] + 1) for i in range(n_samples)],
        alpha=0.4,
        label=r"$\mathbf{x}_i$",
        zorder=30,
        s=50,
    )
    pl.scatter(
        X1[:, 0],
        X1[:, 1],
        color=[cmap(y[i] + 1) for i in range(n_samples)],
        alpha=0.9,
        s=50,
        marker="+",
        label=r"$\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_i$",
        zorder=30,
    )
    pl.title("Data and projections", fontsize=fs)
    pl.xlim(-x_y_lim, x_y_lim)
    pl.xticks([])
    pl.ylim(-x_y_lim, x_y_lim)
    pl.yticks([])
    pl.legend(fontsize=fs, loc="upper left")
    pl.tight_layout()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_EWCA_002.png
   :alt: Data and projections
   :srcset: /auto_examples/others/images/sphx_glr_plot_EWCA_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 152-154

Plot transport plan
-------------------

.. GENERATED FROM PYTHON SOURCE LINES 154-188

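
The heat map below shows ``n_samples * pi * 100``. If both marginals of the
plan are uniform (weights ``1/n``), each row of ``pi`` sums to ``1/n``, so this
rescaling expresses row ``i`` as the percentage of the mass of ``x_i`` sent to
each projected sample. The following sketch illustrates that rescaling with a
plain Sinkhorn loop in NumPy, on hypothetical stand-in data and with ``U``
replaced by a fixed unit vector. The squared Euclidean cost and the fixed
iteration count are assumptions for illustration, not a description of how
``ot.dr.ewca`` computes ``pi``; EWCA also optimizes the component, which stays
fixed here.

.. code-block:: Python

    import numpy as np


    def sinkhorn_plan(a, b, M, reg, n_iter=2000):
        """Plain Sinkhorn: returns diag(u) K diag(v) with K = exp(-M / reg)."""
        K = np.exp(-M / reg)
        u = np.ones_like(a)
        v = np.ones_like(b)
        for _ in range(n_iter):
            u = a / (K @ v)  # rescale rows toward the marginal a
            v = b / (K.T @ u)  # rescale columns to the marginal b
        return u[:, None] * K * v[None, :]


    # Hypothetical stand-in data and a fixed unit component.
    demo_rng = np.random.RandomState(0)
    n_demo = 20
    X_demo = 0.5 * demo_rng.randn(n_demo, 2)
    X_demo -= X_demo.mean(0)
    w = np.array([1.0, 1.0]) / np.sqrt(2)
    X1_demo = X_demo @ np.outer(w, w)  # projected samples

    # Assumed cost: squared Euclidean distance between x_i and projected x_j.
    M_demo = ((X_demo[:, None, :] - X1_demo[None, :, :]) ** 2).sum(-1)
    weights = np.full(n_demo, 1.0 / n_demo)  # uniform weights on both sides
    pi_demo = sinkhorn_plan(weights, weights, M_demo, reg=0.5)

    # Same rescaling as the heat map: row i gives, in %, where the mass of x_i goes.
    percent_demo = n_demo * pi_demo * 100
    print(percent_demo.sum(axis=1).round(2))  # each row should be close to 100

The sketch uses new variable names (``pi_demo``, ``X1_demo``) so that running it
in the same session does not overwrite the ``pi`` and ``X1`` used by the plots.
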

.. code-block:: Python


    fig = pl.figure(figsize=(5, 5))

    # With (approximately) uniform marginals, each row of pi sums to
    # 1 / n_samples, so n_samples * pi * 100 gives, in %, how the mass of x_i is
    # split among the projected samples. The square-root color scale
    # (PowerNorm with gamma = 0.5) makes small entries visible.
    norm = matplotlib.colors.PowerNorm(0.5, vmin=0, vmax=100)
    im = pl.imshow(n_samples * pi * 100, cmap=pl.cm.Blues, norm=norm, aspect="auto")
    cb = fig.colorbar(im, orientation="vertical", shrink=0.8)
    ticks_loc = cb.ax.get_yticks().tolist()
    cb.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
    cb.ax.set_yticklabels([f"{int(i)}%" for i in cb.get_ticks()])
    cb.ax.tick_params(labelsize=fs)

    # Outline in red the within-class blocks of the plan (np.unique is sorted).
    for class_ in np.unique(y):
        indices = y == class_
        idx_min = np.min(np.arange(len(y))[indices])
        idx_max = np.max(np.arange(len(y))[indices])
        width = idx_max - idx_min + 1
        rect = patches.Rectangle(
            (idx_min - 0.5, idx_min - 0.5),
            width,
            width,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        pl.gca().add_patch(rect)

    pl.title("OT plan", fontsize=fs)
    pl.ylabel(r"($\mathbf{x}_1, \cdots, \mathbf{x}_n$)")
    x_label = r"($\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_1, \cdots,"
    x_label += r"\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_n$)"
    pl.xlabel(x_label)
    pl.tight_layout()
    pl.axis("scaled")

    pl.show()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_EWCA_003.png
   :alt: OT plan
   :srcset: /auto_examples/others/images/sphx_glr_plot_EWCA_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.588 seconds)


.. _sphx_glr_download_auto_examples_others_plot_EWCA.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_EWCA.ipynb <plot_EWCA.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_EWCA.py <plot_EWCA.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery