.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_lowrank_sinkhorn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_lowrank_sinkhorn.py:


========================================
Low rank Sinkhorn
========================================

This example illustrates the computation of low-rank Sinkhorn [65], which
constrains the optimal transport plan to have a low nonnegative rank, and
compares it with standard entropic Sinkhorn.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). "Low-rank Sinkhorn
factorization". In International Conference on Machine Learning.

.. GENERATED FROM PYTHON SOURCE LINES 12-24

.. code-block:: Python


    # Author: Laurène David
    #
    # License: MIT License
    #
    # sphinx_gallery_thumbnail_number = 2

    import numpy as np
    import matplotlib.pylab as pl
    import ot.plot
    from ot.datasets import make_1D_gauss as gauss

.. GENERATED FROM PYTHON SOURCE LINES 25-27

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 29-45

.. code-block:: Python


    n = 100
    m = 120

    # Source and target weights: mixtures of two 1D Gaussians, normalized to sum to 1
    a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(n, m=int(5 * n / 6), s=15 / np.sqrt(2))
    a = a / np.sum(a)

    b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(m, m=int(m / 2), s=35 / np.sqrt(2))
    b = b / np.sum(b)

    # Source and target samples: support points on a 1D grid
    X = np.arange(n).reshape(-1, 1)
    Y = np.arange(m).reshape(-1, 1)

.. GENERATED FROM PYTHON SOURCE LINES 46-48

Solve low-rank Sinkhorn
-----------------------

.. GENERATED FROM PYTHON SOURCE LINES 50-51

Solve low-rank Sinkhorn with rank 10 directly on the samples and plot the
resulting transport plan.

.. GENERATED FROM PYTHON SOURCE LINES 51-57

.. code-block:: Python


    # Rank-10 factorization computed directly from the samples X and Y
    Q, R, g, log = ot.lowrank_sinkhorn(X, Y, a, b, rank=10, init="random", gamma_init="rescale",
                                       rescale_cost=True, warn=False, log=True)
    P = log["lazy_plan"][:]  # evaluate the lazy low-rank plan as a dense (n, m) array

    ot.plot.plot1D_mat(a, b, P, 'OT matrix Low rank')
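In [65], the plan is factored as ``P = Q diag(1/g) R^T``, where ``Q`` (``n x r``)
has marginals ``(a, g)`` and ``R`` (``m x r``) has marginals ``(b, g)``; the
``Q``, ``R`` and ``g`` returned above play these roles. The minimal NumPy sketch
below does not solve the low-rank problem. Assuming only this parameterization,
it builds random feasible factors (Sinkhorn-scaling random positive matrices is an
illustrative choice, and ``scale_to_marginals`` is a hypothetical helper, not part
of POT) and checks that the product is a valid coupling of rank at most ``r``,
stored with ``(n + m + 1) r`` numbers instead of ``n m``.

.. code-block:: Python

    import numpy as np


    def scale_to_marginals(K, row, col, n_iter=1000):
        """Hypothetical helper (not part of POT): Sinkhorn-scale a positive
        matrix K so that its row sums match `row` and column sums match `col`."""
        u = np.ones(K.shape[0])
        v = np.ones(K.shape[1])
        for _ in range(n_iter):
            u = row / (K @ v)
            v = col / (K.T @ u)
        return u[:, None] * K * v[None, :]


    rng = np.random.default_rng(0)
    n, m, r = 100, 120, 10

    # Arbitrary toy weights (not the Gaussian mixtures of the example)
    a = rng.random(n)
    a /= a.sum()
    b = rng.random(m)
    b /= b.sum()
    # Inner marginal g, shared by both factors
    g = np.full(r, 1.0 / r)

    # Random feasible factors: Q has marginals (a, g), R has marginals (b, g)
    Q = scale_to_marginals(rng.random((n, r)) + 0.1, a, g)
    R = scale_to_marginals(rng.random((m, r)) + 0.1, b, g)

    # Low-rank plan P = Q diag(1/g) R^T
    P = Q @ np.diag(1.0 / g) @ R.T

    print(np.abs(P.sum(axis=1) - a).max())  # row marginal error
    print(np.abs(P.sum(axis=0) - b).max())  # column marginal error
    print(np.linalg.matrix_rank(P))  # at most r
    print((n + m + 1) * r, "numbers stored vs", n * m)  # 2210 numbers stored vs 12000

With rank 50, the largest rank used in the comparison of this example, the same
count gives 11050 numbers, so the memory saving over the dense ``100 x 120``
plan mostly disappears.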
.. image-sg:: /auto_examples/others/images/sphx_glr_plot_lowrank_sinkhorn_001.png
   :alt: OT matrix Low rank
   :srcset: /auto_examples/others/images/sphx_glr_plot_lowrank_sinkhorn_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 58-61

Sinkhorn vs low-rank Sinkhorn
-----------------------------

Compare the plans of standard entropic Sinkhorn for several regularization
strengths with those of low-rank Sinkhorn for several ranks.

.. GENERATED FROM PYTHON SOURCE LINES 63-76

.. code-block:: Python


    # Squared Euclidean cost matrix for Sinkhorn, normalized so its maximum is 1
    M = ot.dist(X, Y)
    M = M / np.max(M)

    # Solve entropic OT (Sinkhorn) with different regularizations using ot.solve
    list_reg = [0.05, 0.005, 0.001]
    list_P_Sin = []

    for reg in list_reg:
        P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan
        list_P_Sin.append(P)

.. GENERATED FROM PYTHON SOURCE LINES 77-88

.. code-block:: Python


    # Solve low-rank Sinkhorn with different ranks using ot.solve_sample
    list_rank = [3, 10, 50]
    list_P_LR = []

    for rank in list_rank:
        P = ot.solve_sample(X, Y, a, b, method='lowrank', rank=rank).plan
        P = P[:]  # materialize the plan as a dense array (a plain slice if it already is one)
        list_P_LR.append(P)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/ot/lowrank.py:300: UserWarning: Dykstra did not converge. You might want to increase the number of iterations `numItermax`
      warnings.warn(

.. GENERATED FROM PYTHON SOURCE LINES 89-110

.. code-block:: Python


    # Plot the Sinkhorn plans for each regularization
    pl.figure(1, figsize=(10, 4))

    pl.subplot(1, 3, 1)
    pl.imshow(list_P_Sin[0], interpolation='nearest')
    pl.axis('off')
    pl.title('Sinkhorn (reg=0.05)')

    pl.subplot(1, 3, 2)
    pl.imshow(list_P_Sin[1], interpolation='nearest')
    pl.axis('off')
    pl.title('Sinkhorn (reg=0.005)')

    pl.subplot(1, 3, 3)
    pl.imshow(list_P_Sin[2], interpolation='nearest')
    pl.axis('off')
    pl.title('Sinkhorn (reg=0.001)')
    pl.show()
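The regularization ``reg`` passed to ``ot.solve`` controls how blurred the
Sinkhorn plans are: the entropy of the optimal entropic plan does not decrease as
``reg`` grows. The sketch below is a textbook NumPy version of the Sinkhorn
iterations on the kernel ``exp(-M / reg)``, not POT's implementation. The toy
problem, the hypothetical helpers ``sinkhorn_plan`` and ``entropy``, and the
values 0.1 and 0.02 are illustrative assumptions; with much smaller values such as
0.001, entries of ``exp(-M / reg)`` can underflow in float64, which is why
stabilized solvers are preferred for very small ``reg``.

.. code-block:: Python

    import numpy as np


    def sinkhorn_plan(a, b, M, reg, max_iter=10000, tol=1e-9):
        """Hypothetical helper: naive Sinkhorn iterations for entropic OT."""
        K = np.exp(-M / reg)  # Gibbs kernel
        u = np.ones_like(a)
        v = np.ones_like(b)
        for _ in range(max_iter):
            u = a / (K @ v)
            v = b / (K.T @ u)
            # After the v-update the columns sum to b; stop once the rows match a
            P = u[:, None] * K * v[None, :]
            if np.abs(P.sum(axis=1) - a).max() < tol:
                break
        return P


    def entropy(P):
        """Hypothetical helper: Shannon entropy -sum P log P over nonzero entries."""
        p = P[P > 0]
        return -np.sum(p * np.log(p))


    # Self-contained toy 1D problem (smaller than the example above)
    x = np.arange(30, dtype=float)
    y = np.arange(40, dtype=float)
    a = np.exp(-((x - 10) ** 2) / 20)
    a /= a.sum()
    b = np.exp(-((y - 25) ** 2) / 40)
    b /= b.sum()
    M = (x[:, None] - y[None, :]) ** 2
    M /= M.max()

    for reg in [0.1, 0.02]:
        P = sinkhorn_plan(a, b, M, reg)
        print(reg, entropy(P))

The entropy printed for ``reg=0.1`` is larger than for ``reg=0.02``, consistent
with the sharper plans obtained with smaller ``reg``.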
.. image-sg:: /auto_examples/others/images/sphx_glr_plot_lowrank_sinkhorn_002.png
   :alt: Sinkhorn (reg=0.05), Sinkhorn (reg=0.005), Sinkhorn (reg=0.001)
   :srcset: /auto_examples/others/images/sphx_glr_plot_lowrank_sinkhorn_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 111-130

.. code-block:: Python


    # Plot the low-rank Sinkhorn plans for each rank
    pl.figure(2, figsize=(10, 4))

    pl.subplot(1, 3, 1)
    pl.imshow(list_P_LR[0], interpolation='nearest')
    pl.axis('off')
    pl.title('Low rank (rank=3)')

    pl.subplot(1, 3, 2)
    pl.imshow(list_P_LR[1], interpolation='nearest')
    pl.axis('off')
    pl.title('Low rank (rank=10)')

    pl.subplot(1, 3, 3)
    pl.imshow(list_P_LR[2], interpolation='nearest')
    pl.axis('off')
    pl.title('Low rank (rank=50)')
    pl.tight_layout()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_lowrank_sinkhorn_003.png
   :alt: Low rank (rank=3), Low rank (rank=10), Low rank (rank=50)
   :srcset: /auto_examples/others/images/sphx_glr_plot_lowrank_sinkhorn_003.png
   :class: sphx-glr-single-img






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 18.019 seconds)


.. _sphx_glr_download_auto_examples_others_plot_lowrank_sinkhorn.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_lowrank_sinkhorn.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_lowrank_sinkhorn.py `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_