.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_lowrank_GW.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_others_plot_lowrank_GW.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_lowrank_GW.py:

===========================================
Low rank Gromov-Wasserstein between samples
===========================================

Comparison between entropic Gromov-Wasserstein and low rank Gromov-Wasserstein [67]
on two curves in 2D and 3D, both sampled with 200 points.

The squared Euclidean distance is used as the ground cost for both samples.

[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
"Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs".
In International Conference on Machine Learning (ICML), 2022.

.. GENERATED FROM PYTHON SOURCE LINES 16-23

.. code-block:: Python

    # Author: Laurène David
    #
    # License: MIT License
    #
    # sphinx_gallery_thumbnail_number = 3

.. GENERATED FROM PYTHON SOURCE LINES 24-29

.. code-block:: Python

    import numpy as np
    import matplotlib.pylab as pl
    import ot.plot
    import time

.. GENERATED FROM PYTHON SOURCE LINES 30-32

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 34-48

.. code-block:: Python

    n_samples = 200

    # Generate 2D and 3D curves
    theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)
    z = np.linspace(1, 2, n_samples)
    r = z**2 + 1
    x = r * np.sin(theta)
    y = r * np.cos(theta)

    # Source and target distribution
    X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
    Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)

.. GENERATED FROM PYTHON SOURCE LINES 49-51

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 53-54

Plot the source and target samples

.. GENERATED FROM PYTHON SOURCE LINES 54-75
.. code-block:: Python

    fig = pl.figure(1, figsize=(10, 4))

    ax = fig.add_subplot(121)
    ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6)
    ax.tick_params(
        left=False, right=False, labelleft=False, labelbottom=False, bottom=False
    )
    ax.set_title("2D curve (source)")

    ax2 = fig.add_subplot(122, projection="3d")
    ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c="red", linewidth=6)
    ax2.tick_params(
        left=False, right=False, labelleft=False, labelbottom=False, bottom=False
    )
    ax2.view_init(15, -50)
    ax2.set_title("3D curve (target)")

    pl.tight_layout()
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_lowrank_GW_001.png
   :alt: 2D curve (source), 3D curve (target)
   :srcset: /auto_examples/others/images/sphx_glr_plot_lowrank_GW_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 76-78

Entropic Gromov-Wasserstein
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 80-113

.. code-block:: Python

    # Compute cost matrices
    C1 = ot.dist(X, X, metric="sqeuclidean")
    C2 = ot.dist(Y, Y, metric="sqeuclidean")

    # Scale cost matrices so that their largest entry is 1
    r1 = C1.max()
    r2 = C2.max()

    C1 = C1 / r1
    C2 = C2 / r2

    # Solve entropic gw
    reg = 5 * 1e-3

    start = time.time()
    gw, log = ot.gromov.entropic_gromov_wasserstein(
        C1, C2, tol=1e-3, epsilon=reg, log=True, verbose=False
    )
    end = time.time()
    time_entropic = end - start

    entropic_gw_loss = np.round(log["gw_dist"], 3)

    # Plot entropic gw
    pl.figure(2)
    pl.imshow(gw, interpolation="nearest", aspect="auto")
    pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss))
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_lowrank_GW_002.png
   :alt: Entropic Gromov-Wasserstein (loss=0.037)
   :srcset: /auto_examples/others/images/sphx_glr_plot_lowrank_GW_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 114-117

Low rank squared Euclidean cost matrices
----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 117-127
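
The idea behind a low rank squared Euclidean cost can be sketched with plain NumPy. The helper ``lr_sqeuclidean_factors`` below is a hypothetical illustration, not POT's implementation. It assumes the usual expansion ``||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>``, which yields two factors with ``d + 2`` columns each. The sketch also checks that dividing both factors by the square root of a scale divides their product by that scale, which is why this example rescales the factors with ``np.sqrt(r1)`` and ``np.sqrt(r2)``.

```python
import numpy as np


def lr_sqeuclidean_factors(X, Y):
    """Hypothetical helper: return (A, B) with A @ B.T equal to the
    squared Euclidean cost matrix between the rows of X and Y."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    sq_x = np.sum(X**2, axis=1, keepdims=True)  # (n, 1) squared norms of X rows
    sq_y = np.sum(Y**2, axis=1, keepdims=True)  # (m, 1) squared norms of Y rows
    ones_x = np.ones((X.shape[0], 1))
    ones_y = np.ones((Y.shape[0], 1))
    # Row i of A dotted with row j of B gives
    # ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    A = np.hstack([sq_x, ones_x, -2.0 * X])  # (n, d + 2)
    B = np.hstack([ones_y, sq_y, Y])  # (m, d + 2)
    return A, B


# Small random point cloud, used only for this check
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(6, 2))
A, B = lr_sqeuclidean_factors(X_demo, X_demo)

# Dense reference cost matrix computed by brute force
diff = X_demo[:, None, :] - X_demo[None, :, :]
C_dense = np.sum(diff**2, axis=2)

# The product of the factors should match the dense matrix
factor_error = np.max(np.abs(A @ B.T - C_dense))

# Scaling both factors by 1 / sqrt(scale) scales the product by 1 / scale
scale = C_dense.max()
A_scaled, B_scaled = A / np.sqrt(scale), B / np.sqrt(scale)
scaled_error = np.max(np.abs(A_scaled @ B_scaled.T - C_dense / scale))

print("factor error:", factor_error)
print("scaled error:", scaled_error)
```

Storing the two factors takes ``O((n + m)(d + 2))`` memory, compared with ``O(n m)`` for the dense cost matrix.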
.. code-block:: Python

    # Compute the low rank sqeuclidean cost decompositions
    A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)
    B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)

    # Scale the low rank cost matrices: dividing each factor by sqrt(r)
    # divides their product by r, matching the scaling of C1 and C2
    A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)
    B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)

.. GENERATED FROM PYTHON SOURCE LINES 128-131

Low rank Gromov-Wasserstein
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 131-164

.. code-block:: Python

    # Solve low rank gromov-wasserstein with different ranks
    list_rank = [10, 50]
    list_P_GW = []
    list_loss_GW = []
    list_time_GW = []

    for rank in list_rank:
        start = time.time()

        Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(
            X,
            Y,
            reg=0,
            rank=rank,
            rescale_cost=False,
            cost_factorized_Xs=(A1, A2),
            cost_factorized_Xt=(B1, B2),
            seed_init=49,
            numItermax=1000,
            log=True,
            stopThr=1e-6,
        )
        end = time.time()

        # Materialize the lazy low rank plan as a dense array for plotting
        P = log["lazy_plan"][:]
        loss = log["value"]

        list_P_GW.append(P)
        list_loss_GW.append(np.round(loss, 3))
        list_time_GW.append(end - start)

.. GENERATED FROM PYTHON SOURCE LINES 165-166

Plot low rank GW with different ranks

.. GENERATED FROM PYTHON SOURCE LINES 166-180

.. code-block:: Python

    pl.figure(3, figsize=(10, 4))

    pl.subplot(1, 2, 1)
    pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto")
    pl.title("Low rank GW (rank=10, loss={})".format(list_loss_GW[0]))

    pl.subplot(1, 2, 2)
    pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto")
    pl.title("Low rank GW (rank=50, loss={})".format(list_loss_GW[1]))

    pl.tight_layout()
    pl.show()

.. image-sg:: /auto_examples/others/images/sphx_glr_plot_lowrank_GW_003.png
   :alt: Low rank GW (rank=10, loss=0.037), Low rank GW (rank=50, loss=0.036)
   :srcset: /auto_examples/others/images/sphx_glr_plot_lowrank_GW_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 181-182

Compare computation time between entropic GW and low rank GW

.. GENERATED FROM PYTHON SOURCE LINES 182-185
.. code-block:: Python

    print("Entropic GW: {:.2f}s".format(time_entropic))
    print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0]))
    print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1]))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Entropic GW: 0.34s
    Low rank GW (rank=10): 0.34s
    Low rank GW (rank=50): 0.42s

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.710 seconds)

.. _sphx_glr_download_auto_examples_others_plot_lowrank_GW.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_lowrank_GW.ipynb <plot_lowrank_GW.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_lowrank_GW.py <plot_lowrank_GW.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_lowrank_GW.zip <plot_lowrank_GW.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_