.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_dmmot.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_dmmot.py:


===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================

When the cost is a discretized Monge cost, the d-MMOT solver can compute and
minimize the distance among many distributions more quickly than
barycenter-based approaches, because it does not need to compute intermediate
barycenters.

This example compares a primal/dual algorithm for the d-MMOT problem with
classical LP barycenter approaches, in terms of both the time needed to find a
solution and the quality of that solution.

.. code-block:: Python


    # Author: Ronak Mehta
    #         Xizheng Yu
    #
    # License: MIT License


Generating 2 distributions
--------------------------

.. code-block:: Python


    import numpy as np
    import matplotlib.pyplot as pl
    import ot

    np.random.seed(0)

    n = 100  # number of bins
    d = 2  # number of distributions

    # Gaussian distributions
    a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m=mean, s=std
    a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
    A = np.vstack((a1, a2)).T  # shape (n, d): one histogram per column

    # Ground cost: pairwise distances between the bin positions
    x = np.arange(n, dtype=np.float64)
    M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')

    pl.figure(figsize=(6.4, 3))
    pl.plot(x, a1, 'b', label='Source distribution')
    pl.plot(x, a2, 'r', label='Target distribution')
    pl.legend()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_001.png
   :alt: plot dmmot
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_001.png
   :class: sphx-glr-single-img
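
For bins on a 1D grid, the Minkowski distance between two positions reduces to
their absolute difference for any order ``p``. The ground cost above therefore
satisfies ``M[i, j] = |x_i - x_j|``, so moving mass from bin ``i`` to bin ``j``
costs ``|i - j|``. The NumPy sketch below checks this without calling POT. The
helper ``minkowski_1d`` is hypothetical and only illustrates the formula.

.. code-block:: Python

    import numpy as np


    def minkowski_1d(x, p=2.0):
        """Hypothetical helper: pairwise Minkowski distances between 1D points."""
        diff = np.abs(x[:, None] - x[None, :])  # |x_i - x_j| for every pair
        # In one dimension the sum over coordinates has a single term,
        # so (|d| ** p) ** (1 / p) == |d| for any p >= 1.
        return (diff ** p) ** (1.0 / p)


    n = 100
    x = np.arange(n, dtype=np.float64)  # bin positions 0, 1, ..., n - 1
    M_sketch = minkowski_1d(x, p=2.0)

    # The cost of moving mass from bin i to bin j is just |i - j|.
    assert np.allclose(M_sketch, np.abs(x[:, None] - x[None, :]))
    print(M_sketch[0, 5], M_sketch[20, 60])  # 5.0 40.0

For example, ``M_sketch[20, 60]`` is the distance between the means of ``a1``
and ``a2``.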

Minimize the distances among distributions and identify the barycenter
------------------------------------------------------------------------
The two methods minimize different objectives, so their objective values
cannot be compared directly.

.. code-block:: Python


    # L2 barycenter: the weighted arithmetic mean of the histograms
    weights = np.ones(d) / d
    l2_bary = A.dot(weights)

    print('LP Iterations:')
    ot.tic()  # start timing the LP solver
    lp_bary, lp_log = ot.lp.barycenter(
        A, M, weights, solver='interior-point', verbose=False, log=True)
    print('Time\t: ', ot.toc(''))
    print('Obj\t: ', lp_log['fun'])

    print('')
    print('Discrete MMOT Algorithm:')
    ot.tic()  # start timing the d-MMOT solver
    barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
        A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
    dmmot_obj = log['primal objective']
    print('Time\t: ', ot.toc(''))
    print('Obj\t: ', dmmot_obj)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LP Iterations:
    /home/circleci/project/ot/lp/cvx.py:125: OptimizeWarning: Sparse constraint matrix detected; setting 'sparse':True.
      sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
    Time : 304.96243953704834
    Obj : 19.999774740910517

    Discrete MMOT Algorithm:
    Inital: Obj: 39.9995 GradNorm: 739.7831
    Iter 0: Obj: 39.9995 GradNorm: 739.7831
    Iter 100: Obj: 2.0914 GradNorm: 180.6322
    Iter 200: Obj: 1.0583 GradNorm: 434.3777
    Iter 300: Obj: 0.4220 GradNorm: 252.9269
    Iter 400: Obj: 0.2317 GradNorm: 168.8668
    Iter 500: Obj: 0.2116 GradNorm: 384.2968
    Iter 600: Obj: 0.1755 GradNorm: 647.6758
    Iter 700: Obj: 0.1343 GradNorm: 786.2442
    Iter 800: Obj: 0.1021 GradNorm: 810.3703
    Iter 900: Obj: 0.0662 GradNorm: 810.3703
    Iter 1000: Obj: 0.0539 GradNorm: 741.7304
    Iter 1100: Obj: 0.0348 GradNorm: 621.4660
    Iter 1200: Obj: 0.0338 GradNorm: 764.3429
    Iter 1300: Obj: 0.0200 GradNorm: 556.2338
    Iter 1400: Obj: 0.0182 GradNorm: 765.8329
    Iter 1500: Obj: 0.0103 GradNorm: 579.8241
    Iter 1600: Obj: 0.0075 GradNorm: 638.2570
    Iter 1700: Obj: 0.0045 GradNorm: 320.1562
    Iter 1800: Obj: 0.0035 GradNorm: 479.8625
    Iter 1900: Obj: 0.0032 GradNorm: 647.1939
    Iter 2000: Obj: 0.0022 GradNorm: 442.4975
    Iter 2100: Obj: 0.0015 GradNorm: 61.0901
    Iter 2200: Obj: 0.0016 GradNorm: 464.9430
    Iter 2300: Obj: 0.0014 GradNorm: 382.5650
    Iter 2400: Obj: 0.0011 GradNorm: 287.2281
    Iter 2500: Obj: 0.0011 GradNorm: 355.6796
    Iter 2600: Obj: 0.0010 GradNorm: 280.1357
    Iter 2700: Obj: 0.0010 GradNorm: 289.6964
    Iter 2800: Obj: 0.0010 GradNorm: 184.4234
    Iter 2900: Obj: 0.0009 GradNorm: 246.5847
    Iter 3000: Obj: 0.0009 GradNorm: 65.3299
    Iter 3100: Obj: 0.0009 GradNorm: 185.9355
    Iter 3200: Obj: 0.0009 GradNorm: 263.0209
    Iter 3300: Obj: 0.0009 GradNorm: 300.3132
    Iter 3400: Obj: 0.0009 GradNorm: 231.4044
    Iter 3500: Obj: 0.0009 GradNorm: 226.3184
    Iter 3600: Obj: 0.0009 GradNorm: 211.4237
    Iter 3700: Obj: 0.0009 GradNorm: 233.2981
    Iter 3800: Obj: 0.0009 GradNorm: 299.0853
    Iter 3900: Obj: 0.0009 GradNorm: 262.4271
    Time : 4.240612030029297
    Obj : 0.0008940778156514197


Compare the barycenters from both methods
-----------------------------------------
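
Besides a visual comparison, the barycenters can be compared numerically. With
unit-spaced bins and the ground cost ``|x_i - x_j|`` used in this example, the
1-Wasserstein distance between two histograms on the same grid equals the L1
distance between their cumulative sums. The NumPy sketch below illustrates
this. The helper ``w1_on_grid`` is hypothetical and not part of POT. Applied to
a pair such as ``lp_bary`` and ``barys[0]``, it would reduce each comparison to
a single number.

.. code-block:: Python

    import numpy as np


    def w1_on_grid(a, b):
        """Hypothetical helper: 1-Wasserstein distance between two histograms on
        the same unit-spaced 1D grid, with ground cost |x_i - x_j|.

        On the real line, W1 is the integral of |F_a - F_b|, where F_a and F_b
        are the cumulative distribution functions; with unit bins it is a sum.
        """
        a = np.asarray(a, dtype=np.float64)
        b = np.asarray(b, dtype=np.float64)
        return float(np.abs(np.cumsum(a) - np.cumsum(b)).sum())


    n = 100

    # Two point masses 40 bins apart: all of the mass moves 40 bins.
    delta_20 = np.zeros(n)
    delta_20[20] = 1.0
    delta_60 = np.zeros(n)
    delta_60[60] = 1.0
    print(w1_on_grid(delta_20, delta_60))  # 40.0

    # A box histogram and the same box shifted by 10 bins.
    box_a = np.zeros(n)
    box_a[20:25] = 0.2
    box_b = np.zeros(n)
    box_b[30:35] = 0.2
    print(round(w1_on_grid(box_a, box_b), 6))  # 10.0

The cumulative-sum form is exact only in one dimension, which is the setting of
this example.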

.. code-block:: Python


    pl.figure(figsize=(6.4, 3))
    # After optimization, every returned distribution approximates the
    # barycenter; only the first one is plotted.
    pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
    pl.plot(x, lp_bary, label='LP Barycenter')
    pl.plot(x, l2_bary, label='L2 Barycenter')
    pl.plot(x, a1, 'b', label='Source distribution')
    pl.plot(x, a2, 'r', label='Target distribution')
    pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver')
    pl.legend()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_002.png
   :alt: Monge Cost: Barycenters from LP Solver and dmmot solver
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_002.png
   :class: sphx-glr-single-img


More than 2 distributions
--------------------------------------------------
Generate 7 pseudorandom Gaussian distributions with 50 bins.

.. code-block:: Python


    n = 50  # number of bins
    d = 7  # number of distributions

    data = []
    for i in range(d):
        # Random mean in [0, n/2) or [0, n); fixed std of 5
        m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
        a = ot.datasets.make_1D_gauss(n, m=m, s=5)
        data.append(a)

    x = np.arange(n, dtype=np.float64)
    M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
    A = np.vstack(data).T  # shape (n, d): one histogram per column

    pl.figure(figsize=(6.4, 3))
    for i in range(len(data)):
        pl.plot(x, data[i])
    pl.title('Distributions')


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_003.png
   :alt: Distributions
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_003.png
   :class: sphx-glr-single-img
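
Both solvers take the histograms stacked as the columns of a single ``(n, d)``
array ``A``, with each column summing to one. With uniform weights, the L2
barycenter ``A.dot(weights)`` is then the mean of the columns. The NumPy sketch
below rebuilds that layout without POT. It assumes that
``ot.datasets.make_1D_gauss`` evaluates a Gaussian profile at the integer bins
and normalizes it to unit mass. The helper ``gauss_hist`` is hypothetical, and
it uses its own random generator, so the drawn means differ from those in the
example.

.. code-block:: Python

    import numpy as np


    def gauss_hist(n, m, s):
        """Hypothetical stand-in for ot.datasets.make_1D_gauss (assumed formula):
        a Gaussian profile on bins 0..n-1, normalized to unit mass."""
        x = np.arange(n, dtype=np.float64)
        h = np.exp(-((x - m) ** 2) / (2 * s ** 2))
        return h / h.sum()


    rng = np.random.default_rng(0)  # independent generator, not the example's seed
    n, d = 50, 7

    data = []
    for _ in range(d):
        # Same rule as in the example: a mean in [0, n/2) or [0, n)
        m = n * (0.5 * rng.random()) * float(rng.integers(2) + 1)
        data.append(gauss_hist(n, m=m, s=5))

    # Stack as columns: A[:, k] is the k-th histogram.
    A = np.vstack(data).T
    print(A.shape)  # (50, 7)

    # Uniform weights give the L2 barycenter as the mean of the columns.
    weights = np.full(d, 1.0 / d)
    l2_bary = A.dot(weights)
    print(np.allclose(A.sum(axis=0), 1.0), np.isclose(l2_bary.sum(), 1.0))  # True True

Because each column has unit mass and the weights sum to one, the L2 barycenter
is itself a valid histogram. It can still be multimodal, which is why it is
compared against the Wasserstein-based barycenters.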

Minimizing Distances Among Many Distributions
---------------------------------------------
The two methods minimize different objectives, so their objective values
cannot be compared directly.

.. code-block:: Python


    # Perform gradient descent optimization using the d-MMOT method.
    barys = ot.lp.dmmot_monge_1dgrid_optimize(
        A, niters=3000, lr_init=1e-4, lr_decay=0.997)

    # After minimization, any of the returned distributions can be used as an
    # estimate of the barycenter.
    bary = barys[0]

    # Compute the L2 barycenter (weighted mean) and the LP Wasserstein barycenter
    weights = ot.unif(d)
    l2_bary = A.dot(weights)
    lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
                                         verbose=False, log=True)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Inital: Obj: 37.1964 GradNorm: 284.3413
    Iter 0: Obj: 37.1964 GradNorm: 280.9858
    Iter 100: Obj: 3.3320 GradNorm: 136.2204
    Iter 200: Obj: 0.7755 GradNorm: 156.3650
    Iter 300: Obj: 0.4874 GradNorm: 229.6258
    Iter 400: Obj: 0.3684 GradNorm: 238.1008
    Iter 500: Obj: 0.3353 GradNorm: 280.3034
    Iter 600: Obj: 0.2220 GradNorm: 267.4771
    Iter 700: Obj: 0.1678 GradNorm: 284.3413
    Iter 800: Obj: 0.1315 GradNorm: 284.3413
    Iter 900: Obj: 0.0706 GradNorm: 271.7241
    Iter 1000: Obj: 0.0567 GradNorm: 269.5960
    Iter 1100: Obj: 0.0420 GradNorm: 250.8386
    Iter 1200: Obj: 0.0345 GradNorm: 271.8676
    Iter 1300: Obj: 0.0230 GradNorm: 230.8679
    Iter 1400: Obj: 0.0145 GradNorm: 217.5960
    Iter 1500: Obj: 0.0122 GradNorm: 244.9326
    Iter 1600: Obj: 0.0089 GradNorm: 207.8076
    Iter 1700: Obj: 0.0064 GradNorm: 175.3682
    Iter 1800: Obj: 0.0054 GradNorm: 208.5713
    Iter 1900: Obj: 0.0042 GradNorm: 196.3110
    Iter 2000: Obj: 0.0035 GradNorm: 189.1930
    Iter 2100: Obj: 0.0028 GradNorm: 152.9444
    Iter 2200: Obj: 0.0025 GradNorm: 154.9903
    Iter 2300: Obj: 0.0023 GradNorm: 165.5778
    Iter 2400: Obj: 0.0020 GradNorm: 162.6161
    Iter 2500: Obj: 0.0019 GradNorm: 148.9564
    Iter 2600: Obj: 0.0018 GradNorm: 150.7780
    Iter 2700: Obj: 0.0017 GradNorm: 160.8478
    Iter 2800: Obj: 0.0017 GradNorm: 145.0931
    Iter 2900: Obj: 0.0016 GradNorm: 128.1718


Compare the barycenters from both methods
-----------------------------------------

.. code-block:: Python


    pl.figure(figsize=(6.4, 3))
    pl.plot(x, bary, 'g-*', label='Discrete MMOT')
    pl.plot(x, l2_bary, 'k', label='L2 Barycenter')
    pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
    pl.title('Barycenters')
    pl.legend()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_004.png
   :alt: Barycenters
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_004.png
   :class: sphx-glr-single-img


Compare with the original distributions
---------------------------------------

.. code-block:: Python


    pl.figure(figsize=(6.4, 3))
    for i in range(len(data)):
        pl.plot(x, data[i])

    # Only the first d-MMOT distribution is plotted as the barycenter estimate.
    pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
    pl.plot(x, l2_bary, 'k^', label='L2')
    pl.plot(x, lp_bary, 'o', color='grey', label='LP')
    pl.title('Barycenters')
    pl.legend()

    pl.show()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_005.png
   :alt: Barycenters
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 35.576 seconds)


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery