.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_stochastic.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_stochastic.py:

===================
Stochastic examples
===================

This example shows how to use the stochastic optimization algorithms of the
POT library for discrete and semi-continuous measures. The semi-dual solvers
follow [18] and the dual solver follows [19].

[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. Stochastic Optimization
for Large-scale Optimal Transport. Advances in Neural Information Processing
Systems (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. &
Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
International Conference on Learning Representations (2018).

.. GENERATED FROM PYTHON SOURCE LINES 18-29

.. code-block:: Python

    # Author: Kilian Fatras
    #
    # License: MIT License

    import matplotlib.pylab as pl
    import numpy as np

    import ot
    import ot.plot

.. GENERATED FROM PYTHON SOURCE LINES 30-38

Compute the Transportation Matrix for the Semi-Dual Problem
-----------------------------------------------------------

Discrete case
`````````````

Sample two discrete measures and compute their cost matrix ``M``.

.. GENERATED FROM PYTHON SOURCE LINES 38-52

.. code-block:: Python

    n_source = 7
    n_target = 4
    reg = 1
    numItermax = 1000

    a = ot.utils.unif(n_source)
    b = ot.utils.unif(n_target)

    rng = np.random.RandomState(0)
    X_source = rng.randn(n_source, 2)
    Y_target = rng.randn(n_target, 2)
    M = ot.dist(X_source, Y_target)

.. GENERATED FROM PYTHON SOURCE LINES 53-54

Call the stochastic average gradient ("SAG") method to compute the
transportation matrix in the discrete case.

.. GENERATED FROM PYTHON SOURCE LINES 54-59
.. code-block:: Python

    method = "SAG"
    sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
    print(sag_pi)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
     [1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
     [3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
     [2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
     [9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
     [2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
     [4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]

.. GENERATED FROM PYTHON SOURCE LINES 60-66

Semi-Continuous Case
````````````````````

Sample one general measure ``a`` and one discrete measure ``b`` for the
semi-continuous case, sample the points on which the source and target
measures are defined, and compute the cost matrix ``M``. In this small
example, ``a`` is represented by ``n_source`` uniformly weighted samples, so
the data are the same as in the discrete case.

.. GENERATED FROM PYTHON SOURCE LINES 66-81

.. code-block:: Python

    n_source = 7
    n_target = 4
    reg = 1
    numItermax = 1000
    log = True

    a = ot.utils.unif(n_source)
    b = ot.utils.unif(n_target)

    rng = np.random.RandomState(0)
    X_source = rng.randn(n_source, 2)
    Y_target = rng.randn(n_target, 2)
    M = ot.dist(X_source, Y_target)

.. GENERATED FROM PYTHON SOURCE LINES 82-84

Call the averaged stochastic gradient descent ("ASGD") method to compute the
transportation matrix in the semi-continuous case.

.. GENERATED FROM PYTHON SOURCE LINES 84-92

.. code-block:: Python

    method = "ASGD"
    asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(
        a, b, M, reg, method, numItermax, log=log
    )
    print(log_asgd["alpha"], log_asgd["beta"])
    print(asgd_pi)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    [3.89418541 7.69191648 3.88798203 2.63066822 1.4605918 3.30128899
     2.76039982] [-2.55838411 -2.42317354 -0.84802459 5.82958224]
    [[2.36658434e-02 1.00210228e-01 1.89765631e-02 4.50856086e-06]
     [1.19762224e-01 1.34039510e-02 1.48790516e-03 8.20306258e-03]
     [3.18880498e-03 7.40472984e-02 6.56209042e-02 1.35308774e-07]
     [2.34839063e-02 3.25971567e-02 8.63628461e-02 4.13233727e-04]
     [8.78057873e-03 7.27931720e-04 1.11939332e-02 1.22154699e-01]
     [1.95469897e-02 8.84579013e-04 1.95751817e-03 1.20468056e-01]
     [3.77891909e-02 2.62747246e-02 7.63341399e-02 2.45908742e-03]]

.. GENERATED FROM PYTHON SOURCE LINES 93-94

Compare the results with the Sinkhorn algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 94-99

.. code-block:: Python

    sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
    print(sinkhorn_pi)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
     [1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
     [3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
     [2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
     [9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
     [2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
     [4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]

.. GENERATED FROM PYTHON SOURCE LINES 100-104

Plot Transportation Matrices
````````````````````````````

For SAG

.. GENERATED FROM PYTHON SOURCE LINES 104-110

.. code-block:: Python

    pl.figure(4, figsize=(5, 5))
    ot.plot.plot1D_mat(a, b, sag_pi, "semi-dual : OT matrix SAG")
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_stochastic_001.png
   :alt: plot stochastic
   :srcset: /auto_examples/others/images/sphx_glr_plot_stochastic_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 111-112

For ASGD

.. GENERATED FROM PYTHON SOURCE LINES 112-118

.. code-block:: Python

    pl.figure(4, figsize=(5, 5))
    ot.plot.plot1D_mat(a, b, asgd_pi, "semi-dual : OT matrix ASGD")
    pl.show()
.. image-sg:: /auto_examples/others/images/sphx_glr_plot_stochastic_002.png
   :alt: plot stochastic
   :srcset: /auto_examples/others/images/sphx_glr_plot_stochastic_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 119-120

For Sinkhorn

.. GENERATED FROM PYTHON SOURCE LINES 120-126

.. code-block:: Python

    pl.figure(4, figsize=(5, 5))
    ot.plot.plot1D_mat(a, b, sinkhorn_pi, "OT matrix Sinkhorn")
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_stochastic_003.png
   :alt: plot stochastic
   :srcset: /auto_examples/others/images/sphx_glr_plot_stochastic_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 127-135

Compute the Transportation Matrix for the Dual Problem
------------------------------------------------------

Semi-continuous case
````````````````````

Sample one general measure ``a`` and one discrete measure ``b`` for the
semi-continuous case, and compute the cost matrix ``M``.

.. GENERATED FROM PYTHON SOURCE LINES 135-152

.. code-block:: Python

    n_source = 7
    n_target = 4
    reg = 1
    numItermax = 100000
    lr = 0.1
    batch_size = 3
    log = True

    a = ot.utils.unif(n_source)
    b = ot.utils.unif(n_target)

    rng = np.random.RandomState(0)
    X_source = rng.randn(n_source, 2)
    Y_target = rng.randn(n_target, 2)
    M = ot.dist(X_source, Y_target)

.. GENERATED FROM PYTHON SOURCE LINES 153-155

Call the stochastic gradient descent ("SGD") method on the dual problem to
compute the transportation matrix in the semi-continuous case.

.. GENERATED FROM PYTHON SOURCE LINES 156-163

.. code-block:: Python

    sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(
        a, b, M, reg, batch_size, numItermax, lr, log=log
    )
    print(log_sgd["alpha"], log_sgd["beta"])
    print(sgd_dual_pi)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    [0.91165603 2.77176512 1.06822819 0.02120131 0.6126745 1.82423613
     0.11278947] [0.33918268 0.47789947 1.5719034 4.9335652 ]
    [[2.17381811e-02 9.23710741e-02 1.08114235e-02 9.32409360e-08]
     [1.58461053e-02 1.77974690e-03 1.22107416e-04 2.44369039e-05]
     [3.44684454e-03 8.03203669e-02 4.39946687e-02 3.29296540e-09]
     [3.13249603e-02 4.36337600e-02 7.14514820e-02 1.24103327e-05]
     [6.81827650e-02 5.67237379e-03 5.39135941e-02 2.13564527e-02]
     [8.09098220e-02 3.67435169e-03 5.02563876e-03 1.12269228e-02]
     [4.85201684e-02 3.38544380e-02 6.07907676e-02 7.10879755e-05]]

.. GENERATED FROM PYTHON SOURCE LINES 164-168

Compare the results with the Sinkhorn algorithm
```````````````````````````````````````````````

Call the Sinkhorn algorithm from POT.

.. GENERATED FROM PYTHON SOURCE LINES 169-173

.. code-block:: Python

    sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
    print(sinkhorn_pi)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
     [1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
     [3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
     [2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
     [9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
     [2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
     [4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]

.. GENERATED FROM PYTHON SOURCE LINES 174-178

Plot Transportation Matrices
````````````````````````````

For SGD

.. GENERATED FROM PYTHON SOURCE LINES 178-184

.. code-block:: Python

    pl.figure(4, figsize=(5, 5))
    ot.plot.plot1D_mat(a, b, sgd_dual_pi, "dual : OT matrix SGD")
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_stochastic_004.png
   :alt: plot stochastic
   :srcset: /auto_examples/others/images/sphx_glr_plot_stochastic_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 185-186

For Sinkhorn

.. GENERATED FROM PYTHON SOURCE LINES 186-190
.. code-block:: Python

    pl.figure(4, figsize=(5, 5))
    ot.plot.plot1D_mat(a, b, sinkhorn_pi, "OT matrix Sinkhorn")
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_stochastic_005.png
   :alt: plot stochastic
   :srcset: /auto_examples/others/images/sphx_glr_plot_stochastic_005.png
   :class: sphx-glr-single-img
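
Marginals of the Recovered Plans
--------------------------------

The printed matrices hint at a structural difference between the solvers:
every row of the SAG and ASGD plans sums to ``1 / 7``, as ``a`` requires,
while some rows of the dual SGD plan are far from it. As far as we can tell
from the POT source, the semi-dual solvers optimize only ``beta``, recover
``alpha`` as the entropic c-transform of ``beta``, and form the plan
``pi[i, j] = exp((alpha[i] + beta[j] - M[i, j]) / reg) * a[i] * b[j]``, which
makes the row marginals equal to ``a`` by construction. The sketch below
checks this in plain NumPy without calling POT. It assumes that ``ot.dist``
uses its default squared Euclidean cost and reuses the ``beta`` printed by the
ASGD run; ``c_transform_entropic`` is a hypothetical helper written for this
sketch, not a call into the library.

.. code-block:: Python

    import numpy as np

    # Rebuild the data of the example without POT. Assumption: ot.dist uses
    # its default squared Euclidean cost, so M[i, j] = ||x_i - y_j||^2.
    rng = np.random.RandomState(0)
    X_source = rng.randn(7, 2)
    Y_target = rng.randn(4, 2)
    M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
    a = np.full(7, 1 / 7)  # same weights as ot.utils.unif(7)
    b = np.full(4, 1 / 4)  # same weights as ot.utils.unif(4)
    reg = 1.0

    # Dual potential beta copied from the ASGD log printed above (8 decimals).
    beta = np.array([-2.55838411, -2.42317354, -0.84802459, 5.82958224])


    def c_transform_entropic(b, M, reg, beta):
        """Hypothetical helper written for this sketch.

        Returns alpha_i = -reg * log(sum_j b_j * exp((beta_j - M_ij) / reg)),
        computed with a per-row shift (log-sum-exp trick) to avoid underflow.
        """
        r = M - beta[None, :]
        r_min = r.min(axis=1, keepdims=True)
        s = (b[None, :] * np.exp(-(r - r_min) / reg)).sum(axis=1)
        return r_min[:, 0] - reg * np.log(s)


    alpha = c_transform_entropic(b, M, reg, beta)

    # Plan rebuilt from the two potentials.
    pi = np.exp((alpha[:, None] + beta[None, :] - M) / reg) * a[:, None] * b[None, :]

    print(alpha)  # compare with log_asgd["alpha"] printed above
    print(np.abs(pi.sum(axis=1) - a).max())  # row marginals: round-off only
    print(np.abs(pi.sum(axis=0) - b).max())  # column marginals: not exact

If the reconstruction is faithful, the first printed array should agree with
``log_asgd["alpha"]`` up to the rounding of the printed ``beta``, the row
error should be at the level of floating-point round-off, and the column
error should be clearly non-zero. At an optimum of the semi-dual problem the
column marginals would also equal ``b``, so this last number gives a rough
indication of how far the ASGD run is from convergence; the same marginal
check can be applied to ``sag_pi`` and ``sgd_dual_pi``.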