.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_dual_ot_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_backends_plot_dual_ot_pytorch.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_dual_ot_pytorch.py:


======================================================================
Dual OT solvers for entropic and quadratic regularized OT with PyTorch
======================================================================

.. GENERATED FROM PYTHON SOURCE LINES 9-22

.. code-block:: Python

    # Author: Remi Flamary
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 3

    import numpy as np
    import matplotlib.pyplot as pl
    import torch
    import ot
    import ot.plot

.. GENERATED FROM PYTHON SOURCE LINES 23-25

Data generation
---------------

.. GENERATED FROM PYTHON SOURCE LINES 25-43

.. code-block:: Python

    torch.manual_seed(1)

    n_source_samples = 100
    n_target_samples = 100
    theta = 2 * np.pi / 20
    noise_level = 0.1

    Xs, ys = ot.datasets.make_data_classif(
        'gaussrot', n_source_samples, nz=noise_level)
    Xt, yt = ot.datasets.make_data_classif(
        'gaussrot', n_target_samples, theta=theta, nz=noise_level)

    # one of the target modes changes its variance (no linear mapping)
    Xt[yt == 2] *= 3
    Xt = Xt + 4

.. GENERATED FROM PYTHON SOURCE LINES 44-46

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 46-54

.. code-block:: Python

    pl.figure(1, (10, 5))
    pl.clf()
    pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
    pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
    pl.legend(loc=0)
    pl.title('Source and target distributions')

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_001.png
   :alt: Source and target distributions
   :srcset: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'Source and target distributions')

.. GENERATED FROM PYTHON SOURCE LINES 55-57

Convert data to torch tensors
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 57-61

.. code-block:: Python

    xs = torch.tensor(Xs)
    xt = torch.tensor(Xt)

.. GENERATED FROM PYTHON SOURCE LINES 62-64

Estimating dual variables for entropic OT
-----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 64-102

.. code-block:: Python

    # dual potentials: one per source sample (u) and one per target sample (v)
    u = torch.randn(n_source_samples, requires_grad=True)
    v = torch.randn(n_target_samples, requires_grad=True)

    reg = 0.5

    optimizer = torch.optim.Adam([u, v], lr=1)

    # number of iterations
    n_iter = 200

    losses = []

    for i in range(n_iter):

        # negate the dual objective: Adam minimizes, but the dual must be maximized
        loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
        losses.append(float(loss.detach()))

        if i % 10 == 0:
            print("Iter: {:3d}, loss={}".format(i, losses[-1]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


    pl.figure(2)
    pl.plot(losses)
    pl.grid()
    pl.title('Dual objective (negative)')
    pl.xlabel("Iterations")

    # recover the entropic OT plan from the estimated dual potentials
    Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_002.png
   :alt: Dual objective (negative)
   :srcset: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iter:   0, loss=0.20204949002247302
    Iter:  10, loss=-19.545724121626748
    Iter:  20, loss=-30.990987796849776
    Iter:  30, loss=-35.01055298550316
    Iter:  40, loss=-37.28067085437584
    Iter:  50, loss=-38.56571908690233
    Iter:  60, loss=-39.02123573970508
    Iter:  70, loss=-39.1873453193487
    Iter:  80, loss=-39.28292694482585
    Iter:  90, loss=-39.32018928658712
    Iter: 100, loss=-39.337288170061974
    Iter: 110, loss=-39.34580003000355
    Iter: 120, loss=-39.35143309374103
    Iter: 130, loss=-39.35459003550905
    Iter: 140, loss=-39.356434516488875
    Iter: 150, loss=-39.35761692260319
    Iter: 160, loss=-39.35839567723641
    Iter: 170, loss=-39.35893056004739
    Iter: 180, loss=-39.359316934366554
    Iter: 190, loss=-39.35960481807062

.. GENERATED FROM PYTHON SOURCE LINES 103-105

Plot the estimated entropic OT plan
-----------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 105-115

.. code-block:: Python

    pl.figure(3, (10, 5))
    pl.clf()
    ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
    pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
    pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
    pl.legend(loc=0)
    pl.title('Source and target distributions')

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_003.png
   :alt: Source and target distributions
   :srcset: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'Source and target distributions')

.. GENERATED FROM PYTHON SOURCE LINES 116-118

Estimating dual variables for quadratic OT
------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 118-158

.. code-block:: Python

    # dual potentials: one per source sample (u) and one per target sample (v)
    u = torch.randn(n_source_samples, requires_grad=True)
    v = torch.randn(n_target_samples, requires_grad=True)

    reg = 0.01

    optimizer = torch.optim.Adam([u, v], lr=1)

    # number of iterations
    n_iter = 200

    losses = []

    for i in range(n_iter):

        # negate the dual objective: Adam minimizes, but the dual must be maximized
        loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
        losses.append(float(loss.detach()))

        if i % 10 == 0:
            print("Iter: {:3d}, loss={}".format(i, losses[-1]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


    pl.figure(4)
    pl.plot(losses)
    pl.grid()
    pl.title('Dual objective (negative)')
    pl.xlabel("Iterations")

    # recover the quadratically regularized OT plan from the estimated dual potentials
    Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_004.png
   :alt: Dual objective (negative)
   :srcset: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iter:   0, loss=-0.0018442196020623663
    Iter:  10, loss=-19.432684976667407
    Iter:  20, loss=-30.704734675782557
    Iter:  30, loss=-34.653899852917014
    Iter:  40, loss=-37.07142385298934
    Iter:  50, loss=-38.31586367842154
    Iter:  60, loss=-38.80630988859148
    Iter:  70, loss=-39.03457107701675
    Iter:  80, loss=-39.13937703389973
    Iter:  90, loss=-39.18323238806764
    Iter: 100, loss=-39.20488737118959
    Iter: 110, loss=-39.21636041075263
    Iter: 120, loss=-39.22280732082113
    Iter: 130, loss=-39.22659265699188
    Iter: 140, loss=-39.22891667846451
    Iter: 150, loss=-39.23031025926827
    Iter: 160, loss=-39.231130828396495
    Iter: 170, loss=-39.23159207012832
    Iter: 180, loss=-39.231838356108845
    Iter: 190, loss=-39.231973223174855

.. GENERATED FROM PYTHON SOURCE LINES 159-161

Plot the estimated quadratic OT plan
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 161-169

.. code-block:: Python

    pl.figure(5, (10, 5))
    pl.clf()
    ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
    pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
    pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
    pl.legend(loc=0)
    pl.title('OT plan with quadratic regularization')

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_005.png
   :alt: OT plan with quadratic regularization
   :srcset: /auto_examples/backends/images/sphx_glr_plot_dual_ot_pytorch_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'OT plan with quadratic regularization')


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 12.718 seconds)


.. _sphx_glr_download_auto_examples_backends_plot_dual_ot_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_dual_ot_pytorch.ipynb <plot_dual_ot_pytorch.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_dual_ot_pytorch.py <plot_dual_ot_pytorch.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
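
Sketch: the regularized duals in plain PyTorch
----------------------------------------------

The dual objectives maximized in this example can also be written out directly in PyTorch, which makes the link between the dual potentials and the transport plan explicit. The following is a minimal sketch, not POT's implementation of ``ot.stochastic.loss_dual_entropic`` or ``ot.stochastic.loss_dual_quadratic`` (whose defaults and constants may differ). It assumes weights :math:`a, b`, a precomputed squared Euclidean cost :math:`C`, and the regularizers :math:`\varepsilon\,\mathrm{KL}(G \,\|\, a b^\top)` (entropic) and :math:`\varepsilon \sum_{ij} G_{ij}^2 / (a_i b_j)` (quadratic). Under these assumptions the entropic dual is :math:`\langle u, a\rangle + \langle v, b\rangle - \varepsilon \sum_{ij} a_i b_j e^{(u_i + v_j - C_{ij})/\varepsilon}` with plan :math:`G_{ij} = a_i b_j e^{(u_i + v_j - C_{ij})/\varepsilon}`, and the quadratic dual is :math:`\langle u, a\rangle + \langle v, b\rangle - \frac{1}{4\varepsilon} \sum_{ij} a_i b_j \max(0, u_i + v_j - C_{ij})^2` with plan :math:`G_{ij} = a_i b_j \max(0, u_i + v_j - C_{ij}) / (2\varepsilon)`. In both cases the gradient of the dual with respect to :math:`u` is :math:`a - G\mathbf{1}` (and :math:`b - G^\top\mathbf{1}` for :math:`v`), so maximizing the dual drives the marginals of the plan towards :math:`a` and :math:`b`. All helper names (``entropic_dual``, ``quadratic_plan``, ``fit_potentials``, ...) and the toy 1D data are hypothetical.

.. code-block:: Python

    import torch

    # Assumed conventions (hypothetical helpers, not POT's code): reg is epsilon,
    # a and b are the sample weights, C is a precomputed cost matrix.


    def entropic_dual(u, v, C, a, b, reg):
        # <u, a> + <v, b> - reg * sum_ij a_i b_j exp((u_i + v_j - C_ij) / reg)
        K = torch.exp((u[:, None] + v[None, :] - C) / reg)
        return u @ a + v @ b - reg * (a[:, None] * K * b[None, :]).sum()


    def entropic_plan(u, v, C, a, b, reg):
        # G_ij = a_i b_j exp((u_i + v_j - C_ij) / reg), strictly positive
        return a[:, None] * torch.exp((u[:, None] + v[None, :] - C) / reg) * b[None, :]


    def quadratic_dual(u, v, C, a, b, reg):
        # <u, a> + <v, b> - 1 / (4 reg) * sum_ij a_i b_j max(0, u_i + v_j - C_ij)^2
        D = torch.clamp(u[:, None] + v[None, :] - C, min=0.0)
        return u @ a + v @ b - (a[:, None] * D ** 2 * b[None, :]).sum() / (4 * reg)


    def quadratic_plan(u, v, C, a, b, reg):
        # G_ij = a_i b_j max(0, u_i + v_j - C_ij) / (2 reg), exactly zero where u_i + v_j <= C_ij
        D = torch.clamp(u[:, None] + v[None, :] - C, min=0.0)
        return a[:, None] * D * b[None, :] / (2 * reg)


    def marginal_error(G, a, b):
        # L1 gap between the marginals of the plan and the prescribed weights
        return float((G.sum(1) - a).abs().sum() + (G.sum(0) - b).abs().sum())


    def fit_potentials(dual, C, a, b, reg, n_iter=1000, lr=0.1):
        # maximize the dual with Adam by minimizing its negative
        u = torch.zeros(C.shape[0], dtype=C.dtype, requires_grad=True)
        v = torch.zeros(C.shape[1], dtype=C.dtype, requires_grad=True)
        optimizer = torch.optim.Adam([u, v], lr=lr)
        for _ in range(n_iter):
            optimizer.zero_grad()
            loss = -dual(u, v, C, a, b, reg)
            loss.backward()
            optimizer.step()
        return u.detach(), v.detach()


    # toy 1D data (hypothetical): 4 source and 4 target points with uniform weights
    xs = torch.tensor([[0.0], [1.0], [2.0], [3.0]], dtype=torch.float64)
    xt = torch.tensor([[0.5], [1.5], [2.5], [3.5]], dtype=torch.float64)
    C = torch.cdist(xs, xt) ** 2  # squared Euclidean cost
    a = torch.full((4,), 0.25, dtype=torch.float64)
    b = torch.full((4,), 0.25, dtype=torch.float64)
    zeros = torch.zeros(4, dtype=torch.float64)

    results = {}
    for name, dual, plan, reg in [("entropic", entropic_dual, entropic_plan, 1.0),
                                  ("quadratic", quadratic_dual, quadratic_plan, 0.1)]:
        before = marginal_error(plan(zeros, zeros, C, a, b, reg), a, b)
        u, v = fit_potentials(dual, C, a, b, reg)
        G = plan(u, v, C, a, b, reg)
        results[name] = (before, marginal_error(G, a, b), G)
        print(f"{name}: marginal error {before:.3e} -> {results[name][1]:.3e}")

Running it prints, for each regularizer, the L1 marginal error of the plan at :math:`u = v = 0` and after 1000 Adam steps. By construction, the quadratic plan is exactly zero wherever :math:`u_i + v_j \le C_{ij}`, so it can be sparse, whereas the entropic plan is strictly positive everywhere.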