.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_dmmot.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_dmmot.py:


===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================

When the cost is discretized and has the Monge property, the d-MMOT solver can
compute and minimize the transport cost among many distributions more quickly
than barycenter-based approaches, and without any intermediate barycenter
computations. This example compares a primal/dual d-MMOT algorithm with
classical LP barycenter approaches, both in the time needed to find a solution
and in the quality of that solution.

.. GENERATED FROM PYTHON SOURCE LINES 13-19

.. code-block:: Python


    # Author: Ronak Mehta
    #         Xizheng Yu
    #
    # License: MIT License

.. GENERATED FROM PYTHON SOURCE LINES 20-22

Generating 2 distributions
--------------------------

.. GENERATED FROM PYTHON SOURCE LINES 22-42

.. code-block:: Python


    import numpy as np
    import matplotlib.pyplot as pl
    import ot

    np.random.seed(0)

    n = 100  # number of bins
    d = 2  # number of distributions

    # Gaussian distributions
    a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m=mean, s=std
    a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
    A = np.vstack((a1, a2)).T  # histograms stored as columns, shape (n, d)

    # Ground cost between grid points; in 1D this is |x_i - x_j|
    x = np.arange(n, dtype=np.float64)
    M = ot.utils.dist(x.reshape((n, 1)), metric="minkowski")

    pl.figure(1, figsize=(6.4, 3))
    pl.plot(x, a1, "b", label="Source distribution")
    pl.plot(x, a2, "r", label="Target distribution")
    pl.legend()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_001.png
   :alt: plot dmmot
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_001.png
   :class: sphx-glr-single-img
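
For intuition about what the d-MMOT solver optimizes, the sketch below builds a
primal solution greedily with the north-west corner rule, which is known to be
optimal when the cost tensor has the Monge property. It assumes the generalized
EMD cost ``c(i_1, ..., i_d) = max_k i_k - min_k i_k`` on a shared 1D grid of
unit-spaced bins; for ``d = 2`` this is ``|i_1 - i_2|``, the same cost as ``M``
above. The function name ``greedy_mmot_1d`` is hypothetical, and this pure-Python
code illustrates the idea rather than reproducing POT's implementation.

.. code-block:: Python

    def greedy_mmot_1d(hists):
        """Greedy (north-west corner) primal solution of d-MMOT on a 1D grid.

        hists: d histograms of equal length n and equal total mass.
        Returns (objective, plan), where plan maps index tuples to mass.
        """
        residual = [list(h) for h in hists]  # mass not yet assigned, per marginal
        n = len(residual[0])
        idx = [0] * len(residual)  # current bin in each marginal
        objective = 0.0
        plan = {}
        while all(i < n for i in idx):
            vals = [r[i] for r, i in zip(residual, idx)]
            mass = min(vals)  # largest mass that can go to the tuple idx
            k = vals.index(mass)  # a marginal whose current bin is now exhausted
            if mass > 0:
                plan[tuple(idx)] = mass
                # Assumed Monge cost of a tuple of bins: max(idx) - min(idx)
                objective += (max(idx) - min(idx)) * mass
            for r, i in zip(residual, idx):
                r[i] -= mass
            idx[k] += 1  # advance the exhausted marginal to its next bin
        return objective, plan


    # d = 2 sanity check: Dirac masses at bins 0 and 2 are 2 bins apart
    p = [1.0, 0.0, 0.0]
    q = [0.0, 0.0, 1.0]
    print(greedy_mmot_1d([p, q])[0])  # 2.0

For ``d = 2`` the objective is the classical 1D EMD between the two inputs. For
the Gaussians above (means 20 and 60) that distance is roughly 40, in line with
the initial d-MMOT objective printed in the next section.
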
.. GENERATED FROM PYTHON SOURCE LINES 43-47

Minimize the distances among distributions and identify the barycenter
----------------------------------------------------------------------
The two methods minimize different objectives: the LP solver minimizes the
weighted sum of Wasserstein distances from a barycenter to each input, while the
d-MMOT solver runs gradient descent on the input distributions themselves to
drive their multi-marginal transport cost toward zero. Their objective values
therefore cannot be compared directly.

.. GENERATED FROM PYTHON SOURCE LINES 47-70

.. code-block:: Python


    # L2 barycenter: the uniformly weighted average of the histograms
    weights = np.ones(d) / d
    l2_bary = A.dot(weights)

    print("LP Iterations:")
    ot.tic()  # reset the timer so that ot.toc() measures only the LP solve
    lp_bary, lp_log = ot.lp.barycenter(
        A, M, weights, solver="interior-point", verbose=False, log=True
    )
    print("Time\t: ", ot.toc(""))
    print("Obj\t: ", lp_log["fun"])

    print("")
    print("Discrete MMOT Algorithm:")
    ot.tic()
    barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
        A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True
    )
    dmmot_obj = log["primal objective"]
    print("Time\t: ", ot.toc(""))
    print("Obj\t: ", dmmot_obj)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LP Iterations:
    /home/circleci/project/ot/lp/cvx.py:127: OptimizeWarning: Sparse constraint matrix detected; setting 'sparse':True.
      sol = sp.optimize.linprog(
    Time : 265.5730583667755
    Obj : 19.999774737592773

    Discrete MMOT Algorithm:
    Initial: Obj: 39.9995 GradNorm: 739.7831
    Iter 0: Obj: 39.9995 GradNorm: 739.7831
    Iter 100: Obj: 2.0914 GradNorm: 180.6322
    Iter 200: Obj: 1.0583 GradNorm: 434.3777
    Iter 300: Obj: 0.4220 GradNorm: 252.9269
    Iter 400: Obj: 0.2317 GradNorm: 168.8668
    Iter 500: Obj: 0.2116 GradNorm: 384.2968
    Iter 600: Obj: 0.1755 GradNorm: 647.6758
    Iter 700: Obj: 0.1343 GradNorm: 786.2442
    Iter 800: Obj: 0.1021 GradNorm: 810.3703
    Iter 900: Obj: 0.0662 GradNorm: 810.3703
    Iter 1000: Obj: 0.0539 GradNorm: 741.7304
    Iter 1100: Obj: 0.0348 GradNorm: 621.4660
    Iter 1200: Obj: 0.0338 GradNorm: 764.3429
    Iter 1300: Obj: 0.0200 GradNorm: 556.2338
    Iter 1400: Obj: 0.0182 GradNorm: 765.8329
    Iter 1500: Obj: 0.0103 GradNorm: 579.8241
    Iter 1600: Obj: 0.0075 GradNorm: 638.2570
    Iter 1700: Obj: 0.0045 GradNorm: 320.1562
    Iter 1800: Obj: 0.0035 GradNorm: 479.8625
    Iter 1900: Obj: 0.0032 GradNorm: 647.1939
    Iter 2000: Obj: 0.0022 GradNorm: 442.4975
    Iter 2100: Obj: 0.0015 GradNorm: 61.0901
    Iter 2200: Obj: 0.0016 GradNorm: 464.9430
    Iter 2300: Obj: 0.0014 GradNorm: 382.5650
    Iter 2400: Obj: 0.0011 GradNorm: 287.2281
    Iter 2500: Obj: 0.0011 GradNorm: 355.6796
    Iter 2600: Obj: 0.0010 GradNorm: 280.1357
    Iter 2700: Obj: 0.0010 GradNorm: 289.6964
    Iter 2800: Obj: 0.0010 GradNorm: 184.4234
    Iter 2900: Obj: 0.0009 GradNorm: 246.5847
    Iter 3000: Obj: 0.0009 GradNorm: 65.3299
    Iter 3100: Obj: 0.0009 GradNorm: 185.9355
    Iter 3200: Obj: 0.0009 GradNorm: 263.0209
    Iter 3300: Obj: 0.0009 GradNorm: 300.3132
    Iter 3400: Obj: 0.0009 GradNorm: 231.4044
    Iter 3500: Obj: 0.0009 GradNorm: 226.3184
    Iter 3600: Obj: 0.0009 GradNorm: 211.4237
    Iter 3700: Obj: 0.0009 GradNorm: 233.2981
    Iter 3800: Obj: 0.0009 GradNorm: 299.0853
    Iter 3900: Obj: 0.0009 GradNorm: 262.4271
    Time : 4.292856931686401
    Obj : 0.0008940778156521405

.. GENERATED FROM PYTHON SOURCE LINES 71-73

Compare the barycenters from both methods
-----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 73-88
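
Besides the visual comparison, candidate barycenters can be scored numerically.
``ot.lp.barycenter`` minimizes the weighted sum of Wasserstein distances (with
cost ``M``) from the barycenter to each input. On the unit-spaced grid used here,
``M[i, j] = |i - j|``, and each such distance has the closed form
``W1(a, b) = sum_i |F_a(i) - F_b(i)|``, where ``F`` denotes the cumulative sum.
The stdlib sketch below (hypothetical helpers ``w1_1d`` and ``weighted_w1``)
evaluates that objective; since it only iterates over values, it also accepts
NumPy arrays such as ``l2_bary``, ``lp_bary`` or ``barys[0]``.

.. code-block:: Python

    from itertools import accumulate


    def w1_1d(a, b, dx=1.0):
        """W1 distance between two histograms on the same uniform 1D grid.

        Closed form: dx * sum_i |F_a(i) - F_b(i)|, with F the cumulative sum.
        """
        return dx * sum(abs(fa - fb) for fa, fb in zip(accumulate(a), accumulate(b)))


    def weighted_w1(bary, hists, weights):
        """Weighted sum of W1 distances from a candidate barycenter to each input."""
        return sum(w * w1_1d(bary, h) for h, w in zip(hists, weights))


    # Toy inputs: Dirac masses at grid points 0 and 2
    p1 = [1.0, 0.0, 0.0]
    p2 = [0.0, 0.0, 1.0]
    midpoint = [0.0, 1.0, 0.0]  # all mass moved halfway
    average = [0.5, 0.0, 0.5]  # pointwise (L2) average
    print(weighted_w1(midpoint, [p1, p2], [0.5, 0.5]))  # 1.0
    print(weighted_w1(average, [p1, p2], [0.5, 0.5]))  # 1.0

Both candidates score 1.0. With two equally weighted inputs, the triangle
inequality bounds the objective below by half of ``W1(p1, p2)``, and on the line
``W1`` is the L1 distance between CDFs, so the pointwise average attains that
bound as well. The weighted W1 barycenter is therefore not unique, which helps
explain why the LP and L2 barycenters can differ in shape.
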
.. code-block:: Python


    pl.figure(figsize=(6.4, 3))
    # After minimization the d-MMOT distributions are close to one another,
    # so only the first one is plotted.
    pl.plot(x, barys[0], "g-*", label="Discrete MMOT")
    pl.plot(x, lp_bary, label="LP Barycenter")
    pl.plot(x, l2_bary, label="L2 Barycenter")
    pl.plot(x, a1, "b", label="Source distribution")
    pl.plot(x, a2, "r", label="Target distribution")
    pl.title("Monge Cost: Barycenters from LP Solver and dmmot solver")
    pl.legend()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_002.png
   :alt: Monge Cost: Barycenters from LP Solver and dmmot solver
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 89-92

More than 2 distributions
-------------------------
Generate seven Gaussian distributions with pseudorandom means on a grid of 50
bins.

.. GENERATED FROM PYTHON SOURCE LINES 92-113

.. code-block:: Python


    n = 50  # number of bins
    d = 7  # number of distributions
    data = []
    for i in range(d):
        # Pseudorandom mean in [0, n/2) or [0, n), depending on a random factor 1 or 2
        m = n * (0.5 * np.random.rand()) * float(np.random.randint(2) + 1)
        a = ot.datasets.make_1D_gauss(n, m=m, s=5)
        data.append(a)

    x = np.arange(n, dtype=np.float64)
    M = ot.utils.dist(x.reshape((n, 1)), metric="minkowski")
    A = np.vstack(data).T

    pl.figure(figsize=(6.4, 3))
    for i in range(len(data)):
        pl.plot(x, data[i], label=f"Distribution {i}")
    pl.title("Distributions")
    pl.legend()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_003.png
   :alt: Distributions
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_003.png
   :class: sphx-glr-single-img
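
The seven histograms are discretized Gaussians from
``ot.datasets.make_1D_gauss``. Assuming that function evaluates
``exp(-(x - m)**2 / (2 * s**2))`` on the integer grid ``0, ..., n - 1`` and
rescales the values to sum to 1, a stdlib equivalent could look like the sketch
below (``gauss_hist`` is a hypothetical name). Because the grid is truncated at
both ends, a mean drawn close to 0 or ``n`` can produce a visibly clipped bump.

.. code-block:: Python

    import math


    def gauss_hist(n, m, s):
        """Discretized Gaussian on the grid 0, ..., n - 1, normalized to sum to 1.

        Sketch of what we assume ot.datasets.make_1D_gauss computes.
        """
        h = [math.exp(-((x - m) ** 2) / (2 * s**2)) for x in range(n)]
        total = sum(h)
        return [v / total for v in h]


    h = gauss_hist(50, m=20, s=5)
    print(h.index(max(h)))  # 20: the peak sits on the bin at the mean
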
.. GENERATED FROM PYTHON SOURCE LINES 114-118

Minimizing Distances Among Many Distributions
---------------------------------------------
As before, the two methods minimize different objectives, so their objective
values cannot be compared directly.

.. GENERATED FROM PYTHON SOURCE LINES 118-132

.. code-block:: Python


    # Perform gradient descent optimization using the d-MMOT method; the result
    # holds the d updated distributions.
    barys = ot.lp.dmmot_monge_1dgrid_optimize(
        A, niters=3000, lr_init=1e-4, lr_decay=0.997
    )

    # After minimization, any of these distributions can be used as an estimate
    # of the barycenter.
    bary = barys[0]

    # L2 barycenter (uniformly weighted average) and LP Wasserstein barycenter
    weights = ot.unif(d)
    l2_bary = A.dot(weights)
    lp_bary, bary_log = ot.lp.barycenter(
        A, M, weights, solver="interior-point", verbose=False, log=True
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Initial: Obj: 37.1964 GradNorm: 284.3413
    Iter 0: Obj: 37.1964 GradNorm: 280.9858
    Iter 100: Obj: 3.2628 GradNorm: 143.1922
    Iter 200: Obj: 0.8687 GradNorm: 165.7830
    Iter 300: Obj: 0.4563 GradNorm: 235.9619
    Iter 400: Obj: 0.3789 GradNorm: 253.4285
    Iter 500: Obj: 0.3109 GradNorm: 284.3413
    Iter 600: Obj: 0.2467 GradNorm: 284.3413
    Iter 700: Obj: 0.1794 GradNorm: 284.3413
    Iter 800: Obj: 0.1023 GradNorm: 262.5719
    Iter 900: Obj: 0.0815 GradNorm: 276.6912
    Iter 1000: Obj: 0.0575 GradNorm: 258.2634
    Iter 1100: Obj: 0.0450 GradNorm: 233.6365
    Iter 1200: Obj: 0.0292 GradNorm: 218.8698
    Iter 1300: Obj: 0.0264 GradNorm: 262.0572
    Iter 1400: Obj: 0.0161 GradNorm: 212.5559
    Iter 1500: Obj: 0.0132 GradNorm: 231.8016
    Iter 1600: Obj: 0.0091 GradNorm: 193.5355
    Iter 1700: Obj: 0.0069 GradNorm: 195.1973
    Iter 1800: Obj: 0.0053 GradNorm: 186.4350
    Iter 1900: Obj: 0.0043 GradNorm: 184.0869
    Iter 2000: Obj: 0.0035 GradNorm: 195.0077
    Iter 2100: Obj: 0.0028 GradNorm: 157.2132
    Iter 2200: Obj: 0.0024 GradNorm: 169.3930
    Iter 2300: Obj: 0.0022 GradNorm: 161.6787
    Iter 2400: Obj: 0.0020 GradNorm: 147.3635
    Iter 2500: Obj: 0.0018 GradNorm: 162.9417
    Iter 2600: Obj: 0.0017 GradNorm: 144.6790
    Iter 2700: Obj: 0.0016 GradNorm: 164.0792
    Iter 2800: Obj: 0.0016 GradNorm: 121.3507
    Iter 2900: Obj: 0.0015 GradNorm: 150.1533

.. GENERATED FROM PYTHON SOURCE LINES 133-135

Compare the barycenters from both methods
-----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 135-142

.. code-block:: Python


    pl.figure(figsize=(6.4, 3))
    pl.plot(x, bary, "g-*", label="Discrete MMOT")
    pl.plot(x, l2_bary, "k", label="L2 Barycenter")
    pl.plot(x, lp_bary, "k--", label="LP Wasserstein")  # dashed, to tell it apart from L2
    pl.title("Barycenters")
    pl.legend()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_004.png
   :alt: Barycenters
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 143-145

Compare with the original distributions
---------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 145-160

.. code-block:: Python


    pl.figure(figsize=(6.4, 3))
    for i in range(len(data)):
        pl.plot(x, data[i])
    # Plot only the first d-MMOT distribution as the barycenter estimate
    pl.plot(x, barys[0], "g-*", label="Discrete MMOT")
    pl.plot(x, l2_bary, "k^", label="L2")
    pl.plot(x, lp_bary, "o", color="grey", label="LP")
    pl.title("Barycenters")
    pl.legend()
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_dmmot_005.png
   :alt: Barycenters
   :srcset: /auto_examples/others/images/sphx_glr_plot_dmmot_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 29.862 seconds)
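
The last figure plots only ``barys[0]``, relying on the claim that, after
minimization, any of the returned distributions can serve as the barycenter
estimate. One way to check that claim is to measure how far apart the returned
histograms still are. The stdlib sketch below computes the largest pairwise L1
distance; ``max_pairwise_l1`` is a hypothetical helper, and the toy histograms
stand in for real solver output.

.. code-block:: Python

    from itertools import combinations


    def max_pairwise_l1(hists):
        """Largest L1 distance between any two histograms (needs at least two)."""
        return max(
            sum(abs(p - q) for p, q in zip(h1, h2))
            for h1, h2 in combinations(hists, 2)
        )


    # Toy histograms standing in for the solver output (not real results)
    toy = [[0.5, 0.5, 0.0], [0.5, 0.5, 0.0], [0.5, 0.25, 0.25]]
    print(max_pairwise_l1(toy))  # 0.5

Applied as ``max_pairwise_l1(barys)`` (the helper only iterates over values, so
NumPy arrays work too), a value close to zero supports using any single entry
of ``barys`` as the estimate.
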