.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_ssw_unif_torch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_ssw_unif_torch.py:


====================================================
Spherical Sliced-Wasserstein Embedding on the Sphere
====================================================

Here, we move a set of samples on the sphere :math:`S^2 \subset \mathbb{R}^3`
so that they become uniformly distributed, by minimizing with gradient descent
their Spherical Sliced-Wasserstein (SSW) distance to the uniform distribution:

.. math::
    \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})

where :math:`\nu=\mathrm{Unif}(S^2)` and :math:`x_1, \dots, x_n \in S^2` are the
:math:`n` samples being optimized.

.. GENERATED FROM PYTHON SOURCE LINES 16-32

.. code-block:: Python


    # Author: Clément Bonet
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 3

    import numpy as np
    import matplotlib.pyplot as pl
    import matplotlib.animation as animation
    import torch
    import torch.nn.functional as F

    import ot

.. GENERATED FROM PYTHON SOURCE LINES 33-35

Data generation
---------------

We draw :math:`N = 500` points uniformly in the cube :math:`[0, 1)^3` and
project them onto the unit sphere :math:`S^2`. All initial samples therefore
lie in the positive octant, far from the uniform target :math:`\nu`.

.. GENERATED FROM PYTHON SOURCE LINES 35-43

.. code-block:: Python


    torch.manual_seed(1)

    N = 500
    x0 = torch.rand(N, 3)
    x0 = F.normalize(x0, dim=-1)

.. GENERATED FROM PYTHON SOURCE LINES 44-46

Plot data
---------

We plot the initial samples together with a wireframe of the unit sphere.

.. GENERATED FROM PYTHON SOURCE LINES 46-68

.. code-block:: Python


    def plot_sphere(ax):
        # Wireframe of the unit sphere, drawn as two hemispheres z = +/- sqrt(1 - x^2 - y^2)
        xlist = np.linspace(-1.0, 1.0, 50)
        ylist = np.linspace(-1.0, 1.0, 50)
        X, Y = np.meshgrid(xlist, ylist)

        # Clip negative values (points outside the unit disk) to 0 before the square root
        Z = np.sqrt(np.maximum(1.0 - X**2 - Y**2, 0))

        ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3)  # Top half
        ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3)  # Bottom half


    # plot the distributions
    pl.figure(1)
    ax = pl.axes(projection='3d')
    plot_sphere(ax)
    ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5)
    ax.set_title('Data distribution')
    ax.legend()

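As a rough check of how far these initial samples are from :math:`\nu`, the
following NumPy sketch (not part of POT or of the original example) draws
reference samples from :math:`\mathrm{Unif}(S^2)` by normalizing standard
Gaussian vectors, mimics the positive-octant initialization above, and compares
the norm of the empirical mean of each set. This norm is close to 0 for uniform
samples; a small value is necessary but not sufficient for uniformity. The
helpers ``sample_unif_sphere``, ``sample_octant_init`` and
``mean_resultant_length`` are hypothetical names introduced for illustration.

.. code-block:: Python

    import numpy as np


    def sample_unif_sphere(n, d=3, seed=0):
        """Hypothetical helper: draw n points uniformly on the unit sphere in R^d.

        Standard Gaussian vectors are rotation invariant, so normalizing them
        yields samples from the uniform distribution on the sphere.
        """
        rng = np.random.default_rng(seed)
        z = rng.standard_normal((n, d))
        return z / np.linalg.norm(z, axis=1, keepdims=True)


    def sample_octant_init(n, d=3, seed=0):
        """Hypothetical NumPy analogue of the initialization above: uniform points
        in the cube [0, 1)^d projected onto the sphere (all in the positive octant).
        """
        rng = np.random.default_rng(seed)
        u = rng.random((n, d))
        return u / np.linalg.norm(u, axis=1, keepdims=True)


    def mean_resultant_length(x):
        """Norm of the empirical mean of points on the sphere.

        It is close to 0 for uniform samples, but a small value alone does not
        prove uniformity (two antipodal clusters also give a value near 0).
        """
        return float(np.linalg.norm(x.mean(axis=0)))


    x_unif = sample_unif_sphere(500)
    x_init = sample_octant_init(500)
    print("uniform reference:     {:.3f}".format(mean_resultant_length(x_unif)))
    print("octant initialization: {:.3f}".format(mean_resultant_length(x_init)))

The same hypothetical diagnostic could be applied to the optimized samples, for
instance ``mean_resultant_length(xvisu[-1].numpy())`` once the gradient descent
has run.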
.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_ssw_unif_torch_001.png
   :alt: Data distribution
   :srcset: /auto_examples/backends/images/sphx_glr_plot_ssw_unif_torch_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 69-71

Gradient descent
----------------

At each iteration, we evaluate SSW between the current samples and
:math:`\mathrm{Unif}(S^2)` with ``ot.sliced_wasserstein_sphere_unif`` (500 random
projections), take a gradient step with the decaying step size
:math:`150 / \sqrt{i/10 + 1}`, and project the samples back onto the sphere by
normalizing them.

.. GENERATED FROM PYTHON SOURCE LINES 71-101

.. code-block:: Python


    x = x0.clone()
    x.requires_grad_(True)

    n_iter = 100
    lr = 150

    losses = []
    xvisu = torch.zeros(n_iter, N, 3)

    for i in range(n_iter):
        sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
        grad_x = torch.autograd.grad(sw, x)[0]

        # Gradient step with decaying step size, then projection back onto the sphere
        x = x - lr * grad_x / np.sqrt(i / 10 + 1)
        x = F.normalize(x, p=2, dim=1)

        losses.append(sw.item())
        xvisu[i, :, :] = x.detach().clone()

        if i % 100 == 0:
            print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    # Use a new figure number so the loss curve is not drawn on the 3D data figure
    pl.figure(2)
    pl.semilogy(losses)
    pl.grid()
    pl.title('SSW')
    pl.xlabel("Iterations")



.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_ssw_unif_torch_002.png
   :alt: SSW
   :srcset: /auto_examples/backends/images/sphx_glr_plot_ssw_unif_torch_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iter:   0, loss=0.21160438656806946

    Text(0.5, 23.52222222222222, 'Iterations')



.. GENERATED FROM PYTHON SOURCE LINES 102-104

Plot trajectories of generated samples along iterations
-------------------------------------------------------

We show snapshots of the samples every 10 iterations, from iteration 0 to 80.

.. GENERATED FROM PYTHON SOURCE LINES 104-120

.. code-block:: Python


    ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]

    fig = pl.figure(3, (10, 10))
    for i in range(9):
        ax = fig.add_subplot(3, 3, i + 1, projection='3d')
        plot_sphere(ax)
        ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2],
                   label='Data samples', alpha=0.5)
        ax.set_title('Iter. {}'.format(ivisu[i]))
        if i == 0:
            ax.legend()



.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_ssw_unif_torch_003.png
   :alt: Iter. 0, Iter. 10, Iter. 20, Iter. 30, Iter. 40, Iter. 50, Iter. 60, Iter. 70, Iter. 80
   :srcset: /auto_examples/backends/images/sphx_glr_plot_ssw_unif_torch_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 121-123

Animate trajectories of generated samples along iterations
----------------------------------------------------------

Each frame of the animation shows the samples at every third iteration.

.. GENERATED FROM PYTHON SOURCE LINES 123-153

.. code-block:: Python


    pl.figure(4, (8, 8))


    def _update_plot(i):
        # Frame i displays the samples at iteration 3 * i
        i = 3 * i
        pl.clf()
        ax = pl.axes(projection='3d')
        plot_sphere(ax)
        ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples', alpha=0.5)
        ax.axis("off")
        ax.set_xlim((-1.5, 1.5))
        ax.set_ylim((-1.5, 1.5))
        ax.set_title('Iter. {}'.format(i))
        return 1


    print(xvisu.shape)

    i = 0
    ax = pl.axes(projection='3d')
    plot_sphere(ax)
    ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples', alpha=0.5)
    ax.axis("off")
    ax.set_xlim((-1.5, 1.5))
    ax.set_ylim((-1.5, 1.5))
    ax.set_title('Iter. {}'.format(i))

    # n_iter // 3 = 33 frames, showing iterations 0, 3, ..., 96
    ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 3, interval=200, repeat_delay=2000)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([100, 500, 3])




.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 43.232 seconds)


.. _sphx_glr_download_auto_examples_backends_plot_ssw_unif_torch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_ssw_unif_torch.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_ssw_unif_torch.py `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_