.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_unmix_optim_torch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_backends_plot_unmix_optim_torch.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_unmix_optim_torch.py:


=================================
Wasserstein unmixing with PyTorch
=================================

In this example we estimate mixing parameters of a model by minimizing the
Wasserstein distance. In other words, we suppose that a target distribution
:math:`\mu^t` can be expressed as a weighted sum of source distributions
:math:`\mu^s_k` with the following model:

.. math::
    \mu^t = \sum_{k=1}^K w_k\mu^s_k

where :math:`\mathbf{w}` is a vector of size :math:`K` that belongs to the
distribution simplex :math:`\Delta_K`.

In order to estimate this weight vector, we optimize the Wasserstein distance
between the model and the observed :math:`\mu^t` with respect to
:math:`\mathbf{w}`. This leads to the following optimization problem:

.. math::
    \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)

In this example, the minimization is done with a simple projected gradient
descent in PyTorch. We use the automatic backend of POT, which makes the
Wasserstein distance computed by :any:`ot.emd2` differentiable with respect
to its inputs.

.. GENERATED FROM PYTHON SOURCE LINES 31-44

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 2

    import numpy as np
    import matplotlib.pylab as pl
    import ot
    import torch


.. GENERATED FROM PYTHON SOURCE LINES 45-47

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 49-79

.. code-block:: Python


    nt = 100
    nt1 = 10  # number of target samples in the first mode

    ns1 = 50
    ns = 2 * ns1

    rng = np.random.RandomState(2)

    # target samples: two modes
    xt = rng.randn(nt, 2) * 0.2
    xt[:nt1, 0] += 1
    xt[nt1:, 1] += 1

    # source 1: samples around (1, 0)
    xs1 = rng.randn(ns1, 2) * 0.2
    xs1[:, 0] += 1

    # source 2: samples around (0, 1)
    xs2 = rng.randn(ns1, 2) * 0.2
    xs2[:, 1] += 1

    xs = np.concatenate((xs1, xs2))

    # Sample reweighting matrix H
    H = np.zeros((ns, 2))
    H[:ns1, 0] = 1 / ns1
    H[ns1:, 1] = 1 / ns1
    # each column sums to 1 and has nonzero weights only for the samples
    # from the corresponding source distribution

    # cost matrix between source and target samples (squared Euclidean)
    M = ot.dist(xs, xt)


.. GENERATED FROM PYTHON SOURCE LINES 80-82

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 84-93

.. code-block:: Python


    pl.figure(1)
    pl.scatter(xt[:, 0], xt[:, 1], label=r'Target $\mu^t$', alpha=0.5)
    pl.scatter(xs1[:, 0], xs1[:, 1], label=r'Source $\mu^s_1$', alpha=0.5)
    pl.scatter(xs2[:, 0], xs2[:, 1], label=r'Source $\mu^s_2$', alpha=0.5)
    pl.title('Sources and Target distributions')
    pl.legend()

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_unmix_optim_torch_001.png
   :alt: Sources and Target distributions
   :srcset: /auto_examples/backends/images/sphx_glr_plot_unmix_optim_torch_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 94-96

Optimization of the model with respect to the Wasserstein distance
-------------------------------------------------------------------

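The loss below is the Wasserstein distance between the reweighted source
samples (with per-sample weights :math:`\mathbf{H}\mathbf{w}`) and the target
samples. We minimize it by projected gradient descent: after each gradient
step, the weight vector is projected back onto the probability simplex with
:any:`ot.utils.proj_simplex`, the same function used in the script below. As a
minimal sketch of what this projection does (the input vector here is an
arbitrary illustration, not part of the original script):

.. code-block:: Python


    import torch
    import ot

    # an arbitrary vector with a negative entry and a sum different from 1
    v = torch.tensor([0.8, -0.1, 0.5])

    p = ot.utils.proj_simplex(v)  # Euclidean projection onto the simplex
    print(p)        # entries are nonnegative: tensor([0.6500, 0.0000, 0.3500])
    print(p.sum())  # and they sum to 1

This projection is what keeps the iterates feasible, so each step of the loop
below produces a valid mixture vector in :math:`\Delta_K`.
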
.. GENERATED FROM PYTHON SOURCE LINES 99-136

.. code-block:: Python


    # convert numpy arrays to torch tensors
    H2 = torch.tensor(H)
    M2 = torch.tensor(M)

    # weights for the source distributions
    w = torch.tensor(ot.unif(2), requires_grad=True)

    # uniform weights for target
    b = torch.tensor(ot.unif(nt))

    lr = 2e-3  # learning rate
    niter = 500  # number of iterations
    losses = []  # loss along the iterations


    # loss for the minimal Wasserstein estimator
    def get_loss(w):
        a = torch.mv(H2, w)  # distribution reweighting
        return ot.emd2(a, b, M2)  # squared Wasserstein 2


    for i in range(niter):
        loss = get_loss(w)
        losses.append(float(loss))

        loss.backward()

        with torch.no_grad():
            w -= lr * w.grad  # gradient step
            w[:] = ot.utils.proj_simplex(w)  # projection onto the simplex

        w.grad.zero_()


.. GENERATED FROM PYTHON SOURCE LINES 137-139

Estimated weights and convergence of the objective
--------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 139-149

.. code-block:: Python


    we = w.detach().numpy()
    print('Estimated mixture:', we)

    pl.figure(2)
    pl.semilogy(losses)
    pl.grid()
    pl.title('Wasserstein distance')
    pl.xlabel("Iterations")

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_unmix_optim_torch_002.png
   :alt: Wasserstein distance
   :srcset: /auto_examples/backends/images/sphx_glr_plot_unmix_optim_torch_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Estimated mixture: [0.09980706 0.90019294]

    Text(0.5, 23.52222222222222, 'Iterations')

.. GENERATED FROM PYTHON SOURCE LINES 150-152

Plotting the reweighted source distribution
-------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 152-162

.. code-block:: Python


    pl.figure(3)

    # compute the weight of each source sample in the mixture
    ws = H.dot(we)

    pl.scatter(xt[:, 0], xt[:, 1], label=r'Target $\mu^t$', alpha=0.5)
    pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns,
               label=r'Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
    pl.title('Target and reweighted source distributions')
    pl.legend()

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_unmix_optim_torch_003.png
   :alt: Target and reweighted source distributions
   :srcset: /auto_examples/backends/images/sphx_glr_plot_unmix_optim_torch_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.617 seconds)


.. _sphx_glr_download_auto_examples_backends_plot_unmix_optim_torch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_unmix_optim_torch.ipynb <plot_unmix_optim_torch.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_unmix_optim_torch.py <plot_unmix_optim_torch.py>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_