.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_stoch_continuous_ot_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_backends_plot_stoch_continuous_ot_pytorch.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_stoch_continuous_ot_pytorch.py:


======================================================================
Continuous OT plan estimation with PyTorch
======================================================================

.. GENERATED FROM PYTHON SOURCE LINES 9-23

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: MIT License
    # sphinx_gallery_thumbnail_number = 3

    import numpy as np
    import matplotlib.pyplot as pl
    import torch
    from torch import nn
    import ot
    import ot.plot


.. GENERATED FROM PYTHON SOURCE LINES 24-26

Data generation
---------------

.. GENERATED FROM PYTHON SOURCE LINES 26-42

.. code-block:: Python


    torch.manual_seed(42)
    np.random.seed(42)

    n_source_samples = 1000
    n_target_samples = 1000
    theta = 2 * np.pi / 20
    noise_level = 0.1

    Xs = np.random.randn(n_source_samples, 2) * 0.5
    Xt = np.random.randn(n_target_samples, 2) * 2

    # shift the target distribution away from the source
    Xt = Xt + 4


.. GENERATED FROM PYTHON SOURCE LINES 43-45

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 45-54

.. code-block:: Python


    nvisu = 300
    pl.figure(1, (5, 5))
    pl.clf()
    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title('Source and target distributions')


.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_001.png
   :alt: Source and target distributions
   :srcset: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Text(0.5, 1.0, 'Source and target distributions')


.. GENERATED FROM PYTHON SOURCE LINES 55-57

Convert data to torch tensors
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 57-61

.. code-block:: Python


    xs = torch.tensor(Xs)
    xt = torch.tensor(Xt)


.. GENERATED FROM PYTHON SOURCE LINES 62-64

Estimating deep dual variables for entropic OT
----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 64-127

.. code-block:: Python


    torch.manual_seed(42)

    # define the MLP model


    class Potential(torch.nn.Module):
        def __init__(self):
            super(Potential, self).__init__()
            self.fc1 = nn.Linear(2, 200)
            self.fc2 = nn.Linear(200, 1)
            self.relu = torch.nn.ReLU()  # instead of Heaviside step fn

        def forward(self, x):
            output = self.fc1(x)
            output = self.relu(output)  # instead of Heaviside step fn
            output = self.fc2(output)
            return output.ravel()


    u = Potential().double()
    v = Potential().double()

    reg = 1

    optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)

    # number of iterations
    n_iter = 500
    n_batch = 500


    losses = []

    for i in range(n_iter):

        # draw a random minibatch of source and target samples

        iperms = torch.randint(0, n_source_samples, (n_batch,))
        ipermt = torch.randint(0, n_target_samples, (n_batch,))

        xsi = xs[iperms]
        xti = xt[ipermt]

        # minus because we maximize the dual loss
        loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
        losses.append(float(loss.detach()))

        if i % 10 == 0:
            print("Iter: {:3d}, loss={}".format(i, losses[-1]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


    pl.figure(2)
    pl.plot(losses)
    pl.grid()
    pl.title('Dual objective (negative)')
    pl.xlabel("Iterations")


.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_002.png
   :alt: Dual objective (negative)
   :srcset: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iter:   0, loss=0.257329928299894
    Iter:  10, loss=-11.890456785604673
    Iter:  20, loss=-15.58037947236615
    Iter:  30, loss=-18.440996850749865
    Iter:  40, loss=-22.12608610815788
    Iter:  50, loss=-25.275903405782387
    Iter:  60, loss=-27.268827591939186
    Iter:  70, loss=-29.79159074243252
    Iter:  80, loss=-31.63488731570214
    Iter:  90, loss=-32.12722861847872
    Iter: 100, loss=-32.696522621311644
    Iter: 110, loss=-33.46949401889149
    Iter: 120, loss=-32.64206913098603
    Iter: 130, loss=-36.15381635153295
    Iter: 140, loss=-34.28321242161009
    Iter: 150, loss=-35.520585380642636
    Iter: 160, loss=-35.676096587323535
    Iter: 170, loss=-34.45865441165184
    Iter: 180, loss=-34.43596310348253
    Iter: 190, loss=-35.26194570410683
    Iter: 200, loss=-34.01278968196742
    Iter: 210, loss=-36.87401169938976
    Iter: 220, loss=-35.12820568044987
    Iter: 230, loss=-37.6372243096062
    Iter: 240, loss=-35.65926621902025
    Iter: 250, loss=-36.527425475361845
    Iter: 260, loss=-36.126583681704034
    Iter: 270, loss=-31.735871196038286
    Iter: 280, loss=-36.157560505651844
    Iter: 290, loss=-35.070647347170436
    Iter: 300, loss=-34.27069736487665
    Iter: 310, loss=-35.40325557106319
    Iter: 320, loss=-35.75151933211852
    Iter: 330, loss=-35.50578707289677
    Iter: 340, loss=-35.833572391120526
    Iter: 350, loss=-35.54009720287946
    Iter: 360, loss=-33.54715464928039
    Iter: 370, loss=-35.63597866286179
    Iter: 380, loss=-35.85734724064162
    Iter: 390, loss=-37.22195044833466
    Iter: 400, loss=-36.54526255143114
    Iter: 410, loss=-35.20288271113561
    Iter: 420, loss=-35.14673091868496
    Iter: 430, loss=-35.32427691564021
    Iter: 440, loss=-36.51095472212393
    Iter: 450, loss=-36.56664149963812
    Iter: 460, loss=-37.557146416121796
    Iter: 470, loss=-35.7012331965002
    Iter: 480, loss=-36.754363123393354
    Iter: 490, loss=-35.129209682796

    Text(0.5, 23.52222222222222, 'Iterations')


.. GENERATED FROM PYTHON SOURCE LINES 128-130

Plot the density on target for a given source sample
----------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 130-190

.. code-block:: Python


    nv = 100
    xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
    yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)

    XX, YY = np.meshgrid(xl, yl)

    xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)

    wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))
    wxg = wxg / np.sum(wxg)

    xg = torch.tensor(xg)
    wxg = torch.tensor(wxg)


    pl.figure(4, (12, 4))
    pl.clf()
    pl.subplot(1, 3, 1)

    iv = 2
    Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
    pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
    pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title('Density of transported source sample')

    pl.subplot(1, 3, 2)

    iv = 3
    Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
    pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
    pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title('Density of transported source sample')

    pl.subplot(1, 3, 3)

    iv = 6
    Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
    pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
    pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title('Density of transported source sample')


.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_003.png
   :alt: Density of transported source sample, Density of transported source sample, Density of transported source sample
   :srcset: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/examples/backends/plot_stoch_continuous_ot_pytorch.py:159: UserWarning: Legend does not support handles for QuadMesh instances.
    See: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#implementing-a-custom-legend-handler
      pl.legend(loc=0)
    /home/circleci/project/examples/backends/plot_stoch_continuous_ot_pytorch.py:173: UserWarning: Legend does not support handles for QuadMesh instances.
    See: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#implementing-a-custom-legend-handler
      pl.legend(loc=0)
    /home/circleci/project/examples/backends/plot_stoch_continuous_ot_pytorch.py:187: UserWarning: Legend does not support handles for QuadMesh instances.
    See: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#implementing-a-custom-legend-handler
      pl.legend(loc=0)

    Text(0.5, 1.0, 'Density of transported source sample')


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 49.531 seconds)


.. _sphx_glr_download_auto_examples_backends_plot_stoch_continuous_ot_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_stoch_continuous_ot_pytorch.ipynb <plot_stoch_continuous_ot_pytorch.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_stoch_continuous_ot_pytorch.py <plot_stoch_continuous_ot_pytorch.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
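
Sketch of the dual objective
----------------------------

The training loop in this example maximizes an entropic dual objective through ``ot.stochastic.loss_dual_entropic``. As a reading aid, the sketch below writes the standard entropic OT dual on a minibatch in plain NumPy, assuming uniform sample weights and a squared Euclidean cost. The function name ``dual_entropic_objective`` is hypothetical, and the formula is an assumption about what the library routine evaluates, not a copy of its implementation.

.. code-block:: Python

    import numpy as np


    def dual_entropic_objective(u, v, xs, xt, reg=1.0):
        """Entropic OT dual objective on a minibatch (illustrative sketch).

        u, v are the potential values at the source and target samples;
        uniform weights and a squared Euclidean cost are assumed.
        """
        ns, nt = xs.shape[0], xt.shape[0]
        # squared Euclidean cost between every source/target pair
        M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=2)
        # smooth penalty replacing the hard constraint u_i + v_j <= M_ij
        penalty = -reg * np.exp((u[:, None] + v[None, :] - M) / reg)
        return u.mean() + v.mean() + penalty.sum() / (ns * nt)


    # small synthetic minibatch, mimicking the data generation of the example
    rng = np.random.default_rng(0)
    xs_demo = rng.normal(size=(5, 2)) * 0.5
    xt_demo = rng.normal(size=(4, 2)) * 2 + 4

    u_demo = rng.normal(size=5)
    v_demo = rng.normal(size=4)

    value = dual_entropic_objective(u_demo, v_demo, xs_demo, xt_demo)
    shifted = dual_entropic_objective(u_demo + 3.0, v_demo - 3.0, xs_demo, xt_demo)
    zero_value = dual_entropic_objective(np.zeros(5), np.zeros(4), xs_demo, xt_demo)

In this sketch, adding a constant to ``u`` and subtracting it from ``v`` leaves the value unchanged, so only the sum of the two potentials is determined by the objective.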