.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/barycenters/plot_barycenter_lp_vs_entropic.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_barycenters_plot_barycenter_lp_vs_entropic.py:


=================================================================================
1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================

This example compares the entropic-regularized Wasserstein barycenter proposed
in [3] with the exact barycenter obtained by solving the underlying linear
program (LP) with a standard LP solver.

It approximately reproduces Figures 3.1 and 3.2 of the following paper:
Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.

[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
Iterative Bregman projections for regularized transportation problems.
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 4

    import numpy as np
    import matplotlib.pylab as pl
    import ot



Gaussian Data
-------------
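
The two inputs defined below are discretized Gaussians on a grid of ``n`` bins,
and the ground cost is the squared distance between bin indices, rescaled so
that its largest entry is 1. As a minimal sketch, the NumPy code below builds
equivalent inputs without POT. It assumes that ``ot.datasets.make_1D_gauss``
samples an unnormalized Gaussian on ``np.arange(n)`` and rescales it to sum to
one, and that ``ot.utils.dist0`` is used with its default squared-difference
cost. The helpers ``gaussian_histogram`` and ``squared_cost`` are hypothetical.

.. code-block:: Python

    import numpy as np


    def gaussian_histogram(n, m, s):
        # hypothetical helper: Gaussian sampled on bins 0..n-1, normalized to sum to 1
        x = np.arange(n, dtype=np.float64)
        h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
        return h / h.sum()


    def squared_cost(n):
        # hypothetical helper: M[i, j] = (i - j) ** 2, rescaled so that M.max() == 1
        x = np.arange(n, dtype=np.float64)
        M = (x[:, None] - x[None, :]) ** 2
        return M / M.max()


    n = 100
    A = np.vstack((gaussian_histogram(n, 20, 5), gaussian_histogram(n, 60, 8))).T
    M = squared_cost(n)
    print(A.shape, M.shape)  # (100, 2) (100, 100)

Each column of ``A`` is one histogram, which is the layout expected by both
barycenter solvers used in this example. After rescaling, the cost between
neighboring bins is only ``1 / 99**2``, roughly 1.0e-4.
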

.. code-block:: Python


    problems = []

    n = 100  # number of bins

    # bin positions
    x = np.arange(n, dtype=np.float64)

    # Gaussian distributions
    a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m = mean, s = std
    a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

    # matrix A whose columns are the input distributions
    A = np.vstack((a1, a2)).T
    n_distributions = A.shape[1]

    # ground cost (squared distance between bins), normalized to a maximum of 1
    M = ot.utils.dist0(n)
    M /= M.max()

.. code-block:: Python


    pl.figure(1, figsize=(6.4, 3))
    for i in range(n_distributions):
        pl.plot(x, A[:, i])
    pl.title('Distributions')
    pl.tight_layout()



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.png
   :alt: Distributions
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.png
   :class: sphx-glr-single-img

.. code-block:: Python


    alpha = 0.5  # 0 <= alpha <= 1
    weights = np.array([1 - alpha, alpha])

    # L2 barycenter: weighted average of the histograms
    bary_l2 = A.dot(weights)

    # entropic-regularized Wasserstein barycenter
    reg = 1e-3
    ot.tic()
    bary_wass = ot.bregman.barycenter(A, M, reg, weights)
    ot.toc()

    # exact Wasserstein barycenter, computed by solving an LP
    ot.tic()
    bary_wass2 = ot.lp.barycenter(A, M, weights)
    ot.toc()

    pl.figure(2)
    pl.clf()
    pl.subplot(2, 1, 1)
    for i in range(n_distributions):
        pl.plot(x, A[:, i])
    pl.title('Distributions')

    pl.subplot(2, 1, 2)
    pl.plot(x, bary_l2, 'r', label='l2')
    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
    pl.legend()
    pl.title('Barycenters')
    pl.tight_layout()

    problems.append([A, [bary_l2, bary_wass, bary_wass2]])



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.png
   :alt: Distributions, Barycenters
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Elapsed time : 0.0034673213958740234 s
    Elapsed time : 0.2672250270843506 s
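
The green curve is produced by ``ot.bregman.barycenter``, which solves the
entropic problem with the iterative Bregman projections of [3]. The NumPy
sketch below illustrates those projections; it is not POT's implementation.
Each input gets a coupling ``diag(u_k) K diag(v_k)`` with Gibbs kernel
``K = exp(-M / reg)``: the ``u`` update enforces the first marginals ``a_k``,
and the ``v`` update forces all second marginals onto their weighted geometric
mean, which is the current barycenter estimate. The function name, iteration
count and tolerance are assumptions. The sketch uses ``reg = 1e-2`` rather
than the ``1e-3`` above because with ``reg = 1e-3`` the entries of ``K`` far
from the diagonal underflow to exactly zero in float64, and this naive update
has no log-domain stabilization.

.. code-block:: Python

    import numpy as np


    def entropic_barycenter_ibp(A, M, reg, weights, n_iter=2000, tol=1e-10):
        # Hypothetical sketch of iterative Bregman projections (not POT's code).
        # A: (n, k) histograms as columns, M: (n, n) cost, weights: (k,) summing to 1.
        K = np.exp(-M / reg)  # Gibbs kernel
        n, k = A.shape
        v = np.ones((n, k))
        b = np.full(n, 1.0 / n)
        for _ in range(n_iter):
            u = A / (K @ v)  # projection onto the first marginals a_k
            Ktu = K.T @ u
            # weighted geometric mean of the current second marginals v_k * (K^T u_k)
            b_new = np.exp(np.log(v * Ktu) @ weights)
            v = b_new[:, None] / Ktu  # projection onto a common second marginal
            converged = np.abs(b_new - b).sum() < tol
            b = b_new
            if converged:
                break
        return b


    # inputs mirroring the Gaussian example, built without POT
    n = 100
    x = np.arange(n, dtype=np.float64)
    a1 = np.exp(-(x - 20) ** 2 / (2 * 5 ** 2))
    a1 /= a1.sum()
    a2 = np.exp(-(x - 60) ** 2 / (2 * 8 ** 2))
    a2 /= a2.sum()
    A = np.vstack((a1, a2)).T
    M = (x[:, None] - x[None, :]) ** 2
    M /= M.max()

    bary = entropic_barycenter_ibp(A, M, reg=1e-2, weights=np.array([0.5, 0.5]))
    print(bary.sum(), (x * bary).sum())  # total mass and mean position

With equal weights the result is a single bump between the two inputs rather
than the two-bump L2 average. Smaller values of ``reg`` give a sharper
barycenter but, in practice, call for a log-domain or otherwise stabilized
implementation.
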

Stair Data
----------

.. code-block:: Python


    # uniform ("stair") histograms on the open intervals (10, 50) and (60, 80)
    a1 = 1.0 * (x > 10) * (x < 50)
    a2 = 1.0 * (x > 60) * (x < 80)

    a1 /= a1.sum()
    a2 /= a2.sum()

    # matrix A whose columns are the input distributions
    A = np.vstack((a1, a2)).T
    n_distributions = A.shape[1]

    # ground cost (squared distance between bins), normalized to a maximum of 1
    M = ot.utils.dist0(n)
    M /= M.max()

.. code-block:: Python


    pl.figure(1, figsize=(6.4, 3))
    for i in range(n_distributions):
        pl.plot(x, A[:, i])
    pl.title('Distributions')
    pl.tight_layout()



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.png
   :alt: Distributions
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.png
   :class: sphx-glr-single-img

.. code-block:: Python


    alpha = 0.5  # 0 <= alpha <= 1
    weights = np.array([1 - alpha, alpha])

    # L2 barycenter: weighted average of the histograms
    bary_l2 = A.dot(weights)

    # entropic-regularized Wasserstein barycenter
    reg = 1e-3
    ot.tic()
    bary_wass = ot.bregman.barycenter(A, M, reg, weights)
    ot.toc()

    # exact Wasserstein barycenter, computed by solving an LP
    ot.tic()
    bary_wass2 = ot.lp.barycenter(A, M, weights)
    ot.toc()

    problems.append([A, [bary_l2, bary_wass, bary_wass2]])

    pl.figure(2)
    pl.clf()
    pl.subplot(2, 1, 1)
    for i in range(n_distributions):
        pl.plot(x, A[:, i])
    pl.title('Distributions')

    pl.subplot(2, 1, 2)
    pl.plot(x, bary_l2, 'r', label='l2')
    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
    pl.legend()
    pl.title('Barycenters')
    pl.tight_layout()



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.png
   :alt: Distributions, Barycenters
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Elapsed time : 0.004792213439941406 s
    Elapsed time : 0.07863497734069824 s



Dirac Data
----------
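
For the Dirac inputs defined below, the exact barycenter can be predicted by
hand. On the real line with a squared-distance cost, the quantile function of
the Wasserstein barycenter is the weighted average of the input quantile
functions. This is a different route from the LP solved by
``ot.lp.barycenter``, sketched here only as a cross-check; the function name is
hypothetical and it assumes histograms on a common, sorted grid ``x``. It
merges the jump levels of the input CDFs and, on each interval between
consecutive levels, places that interval's mass at the weighted average of the
input quantiles.

.. code-block:: Python

    import numpy as np


    def quantile_barycenter_1d(x, A, weights):
        # Hypothetical sketch: exact 1D W2 barycenter by averaging quantile functions.
        # x: sorted grid (n,), A: (n, k) histograms as columns, weights: (k,) summing to 1.
        # Returns the positions and masses of the barycenter's atoms.
        cdfs = np.cumsum(A, axis=0)
        cdfs = cdfs / cdfs[-1]  # make every CDF end exactly at 1
        levels = np.unique(np.concatenate(([0.0], cdfs.ravel())))
        positions, masses = [], []
        for lo, hi in zip(levels[:-1], levels[1:]):
            t = 0.5 * (lo + hi)  # a level strictly inside (lo, hi)
            # quantile Q_k(t): first grid point whose CDF reaches t
            idx = [min(np.searchsorted(cdfs[:, j], t), len(x) - 1)
                   for j in range(A.shape[1])]
            positions.append(np.dot(x[idx], weights))
            masses.append(hi - lo)
        return np.array(positions), np.array(masses)


    # the Dirac inputs of this section, with alpha = 0.5
    n = 100
    x = np.arange(n, dtype=np.float64)
    a1 = np.zeros(n)
    a1[[10, 20, 30]] = [0.25, 0.5, 0.25]
    a2 = np.zeros(n)
    a2[80] = 1.0
    pos, mass = quantile_barycenter_1d(x, np.vstack((a1, a2)).T, np.array([0.5, 0.5]))
    print(pos, mass)  # atoms at 45, 50, 55 with masses 0.25, 0.5, 0.25

These positions lie on the grid, so up to solver tolerance the exact LP
barycenter should put its mass on bins 45, 50 and 55, whereas the L2
barycenter keeps the two inputs' supports (bins 10, 20, 30 and 80) separate.
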

.. code-block:: Python


    # Dirac masses: a1 spreads its mass over bins 10, 20 and 30, a2 sits on bin 80
    a1 = np.zeros(n)
    a2 = np.zeros(n)

    a1[10] = .25
    a1[20] = .5
    a1[30] = .25
    a2[80] = 1

    a1 /= a1.sum()
    a2 /= a2.sum()

    # matrix A whose columns are the input distributions
    A = np.vstack((a1, a2)).T
    n_distributions = A.shape[1]

    # ground cost (squared distance between bins), normalized to a maximum of 1
    M = ot.utils.dist0(n)
    M /= M.max()

.. code-block:: Python


    pl.figure(1, figsize=(6.4, 3))
    for i in range(n_distributions):
        pl.plot(x, A[:, i])
    pl.title('Distributions')
    pl.tight_layout()



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_005.png
   :alt: Distributions
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_005.png
   :class: sphx-glr-single-img

.. code-block:: Python


    alpha = 0.5  # 0 <= alpha <= 1
    weights = np.array([1 - alpha, alpha])

    # L2 barycenter: weighted average of the histograms
    bary_l2 = A.dot(weights)

    # entropic-regularized Wasserstein barycenter
    reg = 1e-3
    ot.tic()
    bary_wass = ot.bregman.barycenter(A, M, reg, weights)
    ot.toc()

    # exact Wasserstein barycenter, computed by solving an LP
    ot.tic()
    bary_wass2 = ot.lp.barycenter(A, M, weights)
    ot.toc()

    problems.append([A, [bary_l2, bary_wass, bary_wass2]])

    pl.figure(2)
    pl.clf()
    pl.subplot(2, 1, 1)
    for i in range(n_distributions):
        pl.plot(x, A[:, i])
    pl.title('Distributions')

    pl.subplot(2, 1, 2)
    pl.plot(x, bary_l2, 'r', label='l2')
    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
    pl.legend()
    pl.title('Barycenters')
    pl.tight_layout()



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.png
   :alt: Distributions, Barycenters
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Elapsed time : 0.0010290145874023438 s
    Elapsed time : 0.0551142692565918 s



Final figure
------------
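
The figure below compares the three barycenters visually. To compare them
with a single number, one can evaluate the barycenter objective, the weighted
sum of squared Wasserstein-2 distances from each input to a candidate ``b``.
The exact barycenter minimizes it among histograms on the grid, so up to
solver tolerance the L2 and entropic curves cannot score lower than the LP
one. The sketch below computes the squared distance on the line from quantile
functions, a closed form that only holds in 1D and is not what this example
uses; the function names are hypothetical, and costs are measured in squared
bin units, i.e. without the ``M /= M.max()`` rescaling, which does not change
the ranking.

.. code-block:: Python

    import numpy as np


    def w2_squared_1d(x, a, b):
        # Hypothetical sketch: squared W2 distance between histograms a, b on sorted grid x.
        a = np.clip(a, 0, None)  # drop tiny negative values an LP solver may return
        b = np.clip(b, 0, None)
        ca = np.cumsum(a)
        ca /= ca[-1]
        cb = np.cumsum(b)
        cb /= cb[-1]
        levels = np.unique(np.concatenate(([0.0], ca, cb)))
        t = 0.5 * (levels[:-1] + levels[1:])  # one level inside each interval
        qa = x[np.minimum(np.searchsorted(ca, t), len(x) - 1)]  # quantiles of a
        qb = x[np.minimum(np.searchsorted(cb, t), len(x) - 1)]  # quantiles of b
        return np.sum(np.diff(levels) * (qa - qb) ** 2)


    def barycenter_objective(x, A, weights, b):
        # weighted sum of squared W2 distances from each column of A to candidate b
        return sum(w * w2_squared_1d(x, A[:, k], b) for k, w in enumerate(weights))


    # toy check: Diracs at bins 2 and 12, equal weights
    x = np.arange(20, dtype=np.float64)
    a1 = np.zeros(20)
    a1[2] = 1.0
    a2 = np.zeros(20)
    a2[12] = 1.0
    A = np.vstack((a1, a2)).T
    weights = np.array([0.5, 0.5])
    dirac_7 = np.zeros(20)
    dirac_7[7] = 1.0
    print(barycenter_objective(x, A, weights, dirac_7))      # 25.0
    print(barycenter_objective(x, A, weights, A @ weights))  # 50.0

On the toy inputs, the Dirac at bin 7 (the exact barycenter) scores half as
much as the L2 average. In this example, the same function could be called
with ``problems[i][0]``, ``weights`` and each curve stored in
``problems[i][1]``.
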

.. code-block:: Python


    nbm = len(problems)  # number of test problems (Gaussian, stair, Dirac)
    nbm2 = nbm // 2  # middle column, which carries the row titles

    pl.figure(2, (20, 6))
    pl.clf()

    for i in range(nbm):

        A = problems[i][0]
        bary_l2 = problems[i][1][0]
        bary_wass = problems[i][1][1]
        bary_wass2 = problems[i][1][2]

        # top row: input distributions of problem i
        pl.subplot(2, nbm, 1 + i)
        for j in range(n_distributions):
            pl.plot(x, A[:, j])
        if i == nbm2:
            pl.title('Distributions')
        pl.xticks(())
        pl.yticks(())

        # bottom row: the three barycenters of problem i
        pl.subplot(2, nbm, 1 + i + nbm)

        pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
        pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
        pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
        if i == nbm - 1:
            pl.legend()
        if i == nbm2:
            pl.title('Barycenters')

        pl.xticks(())
        pl.yticks(())



.. image-sg:: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_007.png
   :alt: Distributions, Barycenters
   :srcset: /auto_examples/barycenters/images/sphx_glr_plot_barycenter_lp_vs_entropic_007.png
   :class: sphx-glr-single-img



.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 1.533 seconds)
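
For reference, the blue curves come from ``ot.lp.barycenter``, which solves the
fixed-support barycenter problem as a linear program. One standard way to
write that LP uses a coupling ``P_k`` per input and the barycenter ``b`` as
variables: minimize ``sum_k w_k <M, P_k>`` subject to ``P_k 1 = a_k``,
``P_k^T 1 = b`` and ``P_k >= 0``. The sketch below states it with
``scipy.optimize.linprog`` and SciPy's HiGHS solver (assumed available, SciPy
1.6 or later); it is an illustration, not POT's code, and the function name is
hypothetical. With ``k`` inputs on ``n`` bins it has ``k * n**2 + n``
variables, which is consistent with the LP timings above being larger than the
entropic ones.

.. code-block:: Python

    import numpy as np
    import scipy.sparse as sp
    from scipy.optimize import linprog


    def lp_barycenter_sketch(A, M, weights):
        # Hypothetical sketch of the fixed-support barycenter LP.
        # Variables: vec(P_1), ..., vec(P_k) (row-major, n*n each), then b (n).
        n, k = A.shape
        I = sp.identity(n, format="csr")
        ones_row = sp.csr_matrix(np.ones((1, n)))
        row_sums = sp.kron(I, ones_row)  # maps vec(P) to P @ 1
        col_sums = sp.kron(ones_row, I)  # maps vec(P) to P.T @ 1
        zero_b = sp.csr_matrix((n, n))

        blocks, rhs = [], []
        for j in range(k):
            first = [None] * k
            first[j] = row_sums
            blocks.append(first + [zero_b])  # P_j 1 = a_j
            rhs.append(A[:, j])
            second = [None] * k
            second[j] = col_sums
            blocks.append(second + [-I])  # P_j^T 1 - b = 0
            rhs.append(np.zeros(n))

        A_eq = sp.bmat(blocks, format="csr")
        b_eq = np.concatenate(rhs)
        c = np.concatenate([w * M.ravel() for w in weights] + [np.zeros(n)])
        res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
        if res.status != 0:
            raise RuntimeError(res.message)
        return res.x[-n:]  # the barycenter b


    # toy check: Diracs at bins 2 and 12 on a 20-bin grid, equal weights
    x = np.arange(20, dtype=np.float64)
    M = (x[:, None] - x[None, :]) ** 2
    M /= M.max()
    A = np.zeros((20, 2))
    A[2, 0] = 1.0
    A[12, 1] = 1.0
    b = lp_barycenter_sketch(A, M, np.array([0.5, 0.5]))
    print(b.argmax(), b.max())  # expected: all mass on bin 7

For the ``n = 100`` grids of this example with two inputs, the LP has 20,100
variables, and that count grows quadratically with ``n``.
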