.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_stoch_continuous_ot_pytorch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_backends_plot_stoch_continuous_ot_pytorch.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_stoch_continuous_ot_pytorch.py:


======================================================================
Continuous OT plan estimation with PyTorch
======================================================================

.. GENERATED FROM PYTHON SOURCE LINES 9-23

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 3

    import numpy as np
    import matplotlib.pyplot as pl
    import torch
    from torch import nn
    import ot
    import ot.plot

.. GENERATED FROM PYTHON SOURCE LINES 24-26

Data generation
---------------

.. GENERATED FROM PYTHON SOURCE LINES 26-42

.. code-block:: Python


    torch.manual_seed(42)
    np.random.seed(42)

    n_source_samples = 1000
    n_target_samples = 1000
    theta = 2 * np.pi / 20
    noise_level = 0.1

    Xs = np.random.randn(n_source_samples, 2) * 0.5
    Xt = np.random.randn(n_target_samples, 2) * 2

    # shift the target mean to (4, 4); the target also has a larger spread than the source
    Xt = Xt + 4

.. GENERATED FROM PYTHON SOURCE LINES 43-45

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 45-54

.. code-block:: Python


    nvisu = 300
    pl.figure(1, (5, 5))
    pl.clf()
    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", label="Source samples", alpha=0.5)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", label="Target samples", alpha=0.5)
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title("Source and target distributions")



.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_001.png
   :alt: Source and target distributions
   :srcset: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'Source and target distributions')

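As a reference for what the estimated plan should look like, note that both distributions are isotropic Gaussians: the source is :math:`\mathcal{N}(0, 0.5^2 I)` and the target is :math:`\mathcal{N}((4, 4), 2^2 I)`. For the squared Euclidean cost, the unregularized OT (Monge) map between two isotropic Gaussians is affine, :math:`T(x) = m_t + (\sigma_t / \sigma_s)(x - m_s)`, which here gives :math:`T(x) = 4x + 4`. The NumPy-only sketch below regenerates the samples with the same seed and checks that :math:`T` maps the source moments onto the target ones. ``gaussian_ot_map`` is a hypothetical helper written for this illustration and is not part of POT.

.. code-block:: Python

    import numpy as np


    def gaussian_ot_map(x, mean_s, std_s, mean_t, std_t):
        """Hypothetical helper: OT map for the squared Euclidean cost between
        isotropic Gaussians N(mean_s, std_s^2 I) and N(mean_t, std_t^2 I)."""
        return mean_t + (std_t / std_s) * (x - mean_s)


    # regenerate the samples as in the example (same NumPy seed, same draw order)
    np.random.seed(42)
    Xs = np.random.randn(1000, 2) * 0.5
    Xt = np.random.randn(1000, 2) * 2 + 4

    # here T(x) = 4 + (2 / 0.5) * x = 4 * x + 4
    Ts = gaussian_ot_map(Xs, mean_s=0.0, std_s=0.5, mean_t=4.0, std_t=2.0)

    print("T(Xs) mean/std:", Ts.mean(axis=0), Ts.std(axis=0))
    print("Xt    mean/std:", Xt.mean(axis=0), Xt.std(axis=0))

The exact entropic plan is a blurred version of this map. It concentrates around the graph of :math:`T` as ``reg`` decreases and spreads out as ``reg`` grows.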
.. GENERATED FROM PYTHON SOURCE LINES 55-57

Convert data to torch tensors
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 57-61

.. code-block:: Python


    xs = torch.tensor(Xs)
    xt = torch.tensor(Xt)

.. GENERATED FROM PYTHON SOURCE LINES 62-64

Estimating deep dual variables for entropic OT
----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 64-126

.. code-block:: Python


    torch.manual_seed(42)

    # define the MLP used for each dual potential (2 -> 200 -> 1)


    class Potential(torch.nn.Module):
        def __init__(self):
            super(Potential, self).__init__()
            self.fc1 = nn.Linear(2, 200)
            self.fc2 = nn.Linear(200, 1)
            self.relu = torch.nn.ReLU()  # ReLU nonlinearity between the two linear layers

        def forward(self, x):
            output = self.fc1(x)
            output = self.relu(output)  # ReLU nonlinearity between the two linear layers
            output = self.fc2(output)
            return output.ravel()


    u = Potential().double()
    v = Potential().double()

    reg = 1

    optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=0.005)

    # number of iterations and minibatch size
    n_iter = 500
    n_batch = 500


    losses = []

    for i in range(n_iter):
        # draw a random minibatch (with replacement) of source and target indices
        iperms = torch.randint(0, n_source_samples, (n_batch,))
        ipermt = torch.randint(0, n_target_samples, (n_batch,))

        xsi = xs[iperms]
        xti = xt[ipermt]

        # minus sign because we maximize the dual objective while the optimizer minimizes
        loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
        losses.append(float(loss.detach()))

        if i % 10 == 0:
            print("Iter: {:3d}, loss={}".format(i, losses[-1]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


    pl.figure(2)
    pl.plot(losses)
    pl.grid()
    pl.title("Dual objective (negative)")
    pl.xlabel("Iterations")



.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_002.png
   :alt: Dual objective (negative)
   :srcset: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iter:   0, loss=0.257329928299894
    Iter:  10, loss=-11.890456785604675
    Iter:  20, loss=-15.58037947236615
    Iter:  30, loss=-18.440996850749865
    Iter:  40, loss=-22.12608610815788
    Iter:  50, loss=-25.27590340578239
    Iter:  60, loss=-27.268827591939186
    Iter:  70, loss=-29.79159074243252
    Iter:  80, loss=-31.63488731570214
    Iter:  90, loss=-32.127228618478725
    Iter: 100, loss=-32.696522621311644
    Iter: 110, loss=-33.46949401889149
    Iter: 120, loss=-32.64206913098603
    Iter: 130, loss=-36.153816351532946
    Iter: 140, loss=-34.28321242161009
    Iter: 150, loss=-35.520585380642636
    Iter: 160, loss=-35.67609658732353
    Iter: 170, loss=-34.45865441165184
    Iter: 180, loss=-34.43596310348252
    Iter: 190, loss=-35.261945704106836
    Iter: 200, loss=-34.01278968196741
    Iter: 210, loss=-36.87401169938976
    Iter: 220, loss=-35.128205680449874
    Iter: 230, loss=-37.63722430960618
    Iter: 240, loss=-35.659266219020246
    Iter: 250, loss=-36.527425475361845
    Iter: 260, loss=-36.126583681704034
    Iter: 270, loss=-31.735871196038296
    Iter: 280, loss=-36.157560505651844
    Iter: 290, loss=-35.070647347170436
    Iter: 300, loss=-34.27069736487666
    Iter: 310, loss=-35.4032555710632
    Iter: 320, loss=-35.7515193321185
    Iter: 330, loss=-35.505787072896766
    Iter: 340, loss=-35.833572391120526
    Iter: 350, loss=-35.540097202879465
    Iter: 360, loss=-33.547154649280394
    Iter: 370, loss=-35.635978662861795
    Iter: 380, loss=-35.85734724064163
    Iter: 390, loss=-37.221950448334646
    Iter: 400, loss=-36.545262551431136
    Iter: 410, loss=-35.202882711135615
    Iter: 420, loss=-35.14673091868494
    Iter: 430, loss=-35.32427691564021
    Iter: 440, loss=-36.51095472212391
    Iter: 450, loss=-36.56664149963812
    Iter: 460, loss=-37.5571464161218
    Iter: 470, loss=-35.7012331965002
    Iter: 480, loss=-36.75436312339336
    Iter: 490, loss=-35.12920968279601

    Text(0.5, 23.52222222222222, 'Iterations')

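The training loop maximizes a minibatch estimate of the entropic OT dual, which is why ``loss`` carries a minus sign. Assume uniform weights :math:`1/n` and :math:`1/m` on the source and target batches, the squared Euclidean cost :math:`C_{ij} = \|x_i - y_j\|^2`, and :math:`\varepsilon` = ``reg``. Under these assumptions, the estimate reads

.. math::

    \frac{1}{n}\sum_i u(x_i) + \frac{1}{m}\sum_j v(y_j) - \frac{\varepsilon}{nm}\sum_{i,j} \exp\left(\frac{u(x_i) + v(y_j) - C_{ij}}{\varepsilon}\right)

The NumPy sketch below evaluates this expression for given potential values. It was written from the formula above and not taken from POT, so ``entropic_dual_batch`` and ``sq_euclidean`` are hypothetical helpers. ``ot.stochastic.loss_dual_entropic`` remains the reference implementation.

.. code-block:: Python

    import numpy as np


    def sq_euclidean(xs, xt):
        """Squared Euclidean cost matrix, C[i, j] = ||xs[i] - xt[j]||^2."""
        return ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)


    def entropic_dual_batch(u, v, xs, xt, reg=1.0):
        """Hypothetical helper: minibatch estimate of the entropic OT dual,
        with uniform weights 1/n on the source batch and 1/m on the target batch."""
        n, m = xs.shape[0], xt.shape[0]
        C = sq_euclidean(xs, xt)
        # smooth relaxation of the hard dual constraint u_i + v_j <= C_ij
        penalty = reg * np.exp((u[:, None] + v[None, :] - C) / reg)
        return u.mean() + v.mean() - penalty.sum() / (n * m)


    rng = np.random.default_rng(0)
    xs = rng.normal(size=(5, 2)) * 0.5  # small source batch
    xt = rng.normal(size=(7, 2)) * 2 + 4  # small target batch

    u0, v0 = np.zeros(5), np.zeros(7)
    print(entropic_dual_batch(u0, v0, xs, xt, reg=1.0))

    # shifting the potentials by (+c, -c) leaves the objective unchanged
    c = 0.3
    print(entropic_dual_batch(u0 + c, v0 - c, xs, xt, reg=1.0))

The two printed values agree up to floating-point rounding. The dual potentials are therefore only defined up to such a constant shift, so the absolute values of the learned networks ``u`` and ``v`` are not meaningful on their own.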
.. GENERATED FROM PYTHON SOURCE LINES 127-129

Plot the density on target for a given source sample
----------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 129-222

.. code-block:: Python


    nv = 100
    xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
    yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)

    XX, YY = np.meshgrid(xl, yl)

    xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)

    wxg = np.exp(-((xg[:, 0] - 4) ** 2 + (xg[:, 1] - 4) ** 2) / (2 * 2))
    wxg = wxg / np.sum(wxg)

    xg = torch.tensor(xg)
    wxg = torch.tensor(wxg)


    pl.figure(4, (12, 4))
    pl.clf()
    pl.subplot(1, 3, 1)

    iv = 2
    Gg = ot.stochastic.plan_dual_entropic(
        u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg
    )
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", zorder=2, alpha=0.05)
    pl.scatter(
        Xs[iv : iv + 1, 0],
        Xs[iv : iv + 1, 1],
        s=100,
        marker="+",
        label="Source sample",
        zorder=2,
        alpha=1,
        color="C0",
    )
    pl.pcolormesh(XX, YY, Gg, cmap="Greens", label="Density of transported source sample")
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title("Density of transported source sample")

    pl.subplot(1, 3, 2)

    iv = 3
    Gg = ot.stochastic.plan_dual_entropic(
        u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg
    )
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", zorder=2, alpha=0.05)
    pl.scatter(
        Xs[iv : iv + 1, 0],
        Xs[iv : iv + 1, 1],
        s=100,
        marker="+",
        label="Source sample",
        zorder=2,
        alpha=1,
        color="C0",
    )
    pl.pcolormesh(XX, YY, Gg, cmap="Greens", label="Density of transported source sample")
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title("Density of transported source sample")

    pl.subplot(1, 3, 3)

    iv = 6
    Gg = ot.stochastic.plan_dual_entropic(
        u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg
    )
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", zorder=2, alpha=0.05)
    pl.scatter(
        Xs[iv : iv + 1, 0],
        Xs[iv : iv + 1, 1],
        s=100,
        marker="+",
        label="Source sample",
        zorder=2,
        alpha=1,
        color="C0",
    )
    pl.pcolormesh(XX, YY, Gg, cmap="Greens", label="Density of transported source sample")
    pl.legend(loc=0)
    ax_bounds = pl.axis()
    pl.title("Density of transported source sample")



.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_003.png
   :alt: Density of transported source sample, Density of transported source sample, Density of transported source sample
   :srcset: /auto_examples/backends/images/sphx_glr_plot_stoch_continuous_ot_pytorch_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/examples/backends/plot_stoch_continuous_ot_pytorch.py:169: UserWarning: Legend does not support handles for QuadMesh instances.
    See: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#implementing-a-custom-legend-handler
      pl.legend(loc=0)
    /home/circleci/project/examples/backends/plot_stoch_continuous_ot_pytorch.py:194: UserWarning: Legend does not support handles for QuadMesh instances.
    See: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#implementing-a-custom-legend-handler
      pl.legend(loc=0)
    /home/circleci/project/examples/backends/plot_stoch_continuous_ot_pytorch.py:219: UserWarning: Legend does not support handles for QuadMesh instances.
    See: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#implementing-a-custom-legend-handler
      pl.legend(loc=0)

    Text(0.5, 1.0, 'Density of transported source sample')

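The green densities above come from ``ot.stochastic.plan_dual_entropic``. With the same cost convention, the entropic plan recovered from the dual potentials is :math:`\pi_{ij} = a_i b_j \exp((u(x_i) + v(y_j) - C_{ij})/\varepsilon)`. Here :math:`b_j` plays the role of ``wxg`` on the grid, and :math:`a_i = 1`, assuming the default uniform source weight, which equals 1 for a single sample. The NumPy sketch below computes one row of this plan on a grid. ``plan_row`` is a hypothetical helper, and because no trained network is available in a standalone snippet, it uses placeholder potentials :math:`u(x_i) = 0` and :math:`v = 0`.

.. code-block:: Python

    import numpy as np


    def plan_row(u_i, v, x_i, grid, wt, reg=1.0, a_i=1.0):
        """Hypothetical helper: one row of the entropic plan,
        pi[j] = a_i * wt[j] * exp((u_i + v[j] - ||x_i - grid[j]||^2) / reg)."""
        cost = ((grid - x_i) ** 2).sum(axis=1)  # squared Euclidean cost to each grid point
        return a_i * wt * np.exp((u_i + v - cost) / reg)


    # flattened 2D grid of shape (nv * nv, 2); the bounds are arbitrary here
    # (the example uses the axis bounds of the first figure instead)
    nv = 50
    xl = np.linspace(-2.0, 10.0, nv)
    XX, YY = np.meshgrid(xl, xl)
    grid = np.stack((XX.ravel(), YY.ravel()), axis=1)

    # normalized Gaussian weights centered at (4, 4), same form as wxg in the example
    wt = np.exp(-((grid[:, 0] - 4) ** 2 + (grid[:, 1] - 4) ** 2) / (2 * 2))
    wt = wt / wt.sum()

    # placeholder potentials: u(x_i) = 0 and v = 0 on the whole grid
    x_i = np.array([0.5, -0.2])
    row = plan_row(0.0, np.zeros(len(grid)), x_i, grid, wt, reg=1.0)
    density = row.reshape(nv, nv)  # same layout as Gg in the example, ready for pcolormesh
    print("peak of the row at", grid[np.argmax(row)])

With :math:`v = 0`, the row is simply the grid weights multiplied by a Gaussian kernel centered at :math:`x_i`, so its peak lies on the segment between :math:`x_i` and (4, 4). In the example, the learned potential :math:`v` is what turns this row into the plotted transport density.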
.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 43.026 seconds)


.. _sphx_glr_download_auto_examples_backends_plot_stoch_continuous_ot_pytorch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_stoch_continuous_ot_pytorch.ipynb <plot_stoch_continuous_ot_pytorch.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_stoch_continuous_ot_pytorch.py <plot_stoch_continuous_ot_pytorch.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_stoch_continuous_ot_pytorch.zip <plot_stoch_continuous_ot_pytorch.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_