.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/domain-adaptation/plot_otda_d2.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_domain-adaptation_plot_otda_d2.py:


===================================================
OT for domain adaptation on empirical distributions
===================================================

This example introduces domain adaptation in a 2D setting. It illustrates the
domain adaptation problem and presents several optimal transport approaches
to solve it.

Quantities such as optimal couplings, the largest coupling coefficients, and
transported samples are plotted to give a visual understanding of what each
transport method does.

.. GENERATED FROM PYTHON SOURCE LINES 15-27

.. code-block:: Python

    # Authors: Remi Flamary
    #          Stanislas Chambon
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 2

    import matplotlib.pyplot as pl
    import ot
    import ot.plot

.. GENERATED FROM PYTHON SOURCE LINES 28-30

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 30-41

.. code-block:: Python

    n_samples_source = 150
    n_samples_target = 150

    # Labeled source samples (Xs, ys) and target samples (Xt, yt);
    # yt is only used to color the plots, the transport methods do not use it
    Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
    Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)

    # Cost matrix of pairwise squared Euclidean distances (displayed in Fig 1)
    M = ot.dist(Xs, Xt, metric='sqeuclidean')

.. GENERATED FROM PYTHON SOURCE LINES 42-44

Instantiate the different transport algorithms and fit them
-----------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 44-63
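
``SinkhornTransport`` solves an entropy-regularized version of the OT problem,
in which ``reg_e`` weights the entropy term; its optimal coupling has the form
``diag(u) @ K @ diag(v)`` with ``K = exp(-M / reg_e)``. The classic way to
compute ``u`` and ``v`` is Sinkhorn-Knopp matrix scaling. The following is a
minimal NumPy sketch of those iterations, assuming uniform sample weights;
``sinkhorn_sketch`` and its ``n_iter`` and ``tol`` arguments are hypothetical
names, and this is not POT's own implementation.

.. code-block:: Python

    import numpy as np


    def sinkhorn_sketch(a, b, M, reg, n_iter=1000, tol=1e-9):
        """Hypothetical helper: entropic OT coupling by Sinkhorn-Knopp scaling.

        a, b : source and target weights, each summing to 1
        M    : (len(a), len(b)) cost matrix
        reg  : entropic regularization strength (the role played by reg_e)
        """
        K = np.exp(-M / reg)  # Gibbs kernel, strictly positive
        u = np.ones_like(a)
        v = np.ones_like(b)
        for _ in range(n_iter):
            v = b / (K.T @ u)  # rescale columns to match the target weights
            u = a / (K @ v)    # rescale rows to match the source weights
            # Rows now match a; stop once the column sums match b as well
            col_err = np.abs(v * (K.T @ u) - b).max()
            if col_err < tol:
                break
        return u[:, None] * K * v[None, :]


    # Toy usage: 5 source and 6 target points with uniform weights
    rng = np.random.default_rng(0)
    xs = rng.normal(size=(5, 2))
    xt = rng.normal(loc=1.0, size=(6, 2))
    M_toy = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
    a = np.full(5, 1 / 5)
    b = np.full(6, 1 / 6)
    G = sinkhorn_sketch(a, b, M_toy, reg=1.0)
    print(G.sum(axis=1))  # row sums, equal to a up to rounding
    print(G.sum(axis=0))  # column sums, close to b

Smaller regularization values bring the coupling closer to the exact EMD
solution but make the kernel entries tiny and convergence slower. When the
iteration cap is reached before the marginals match to the tolerance, POT emits
a warning such as the ``UserWarning`` shown in the script output below.
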
.. code-block:: Python

    # EMD Transport
    ot_emd = ot.da.EMDTransport()
    ot_emd.fit(Xs=Xs, Xt=Xt)

    # Sinkhorn Transport
    ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
    ot_sinkhorn.fit(Xs=Xs, Xt=Xt)

    # Sinkhorn Transport with non-convex group-lasso (Lp-L1) class
    # regularization; this method also uses the source labels ys
    ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
    ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)

    # transport source samples onto target samples
    transp_Xs_emd = ot_emd.transform(Xs=Xs)
    transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
    transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/ot/bregman/_sinkhorn.py:531: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
      warnings.warn("Sinkhorn did not converge. You might want to "

.. GENERATED FROM PYTHON SOURCE LINES 64-66

Fig 1: plot source and target samples and the pairwise distance matrix
----------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 66-90

.. code-block:: Python

    pl.figure(1, figsize=(10, 10))

    pl.subplot(2, 2, 1)
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title('Source samples')

    pl.subplot(2, 2, 2)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title('Target samples')

    pl.subplot(2, 2, 3)
    pl.imshow(M, interpolation='nearest')
    pl.xticks([])
    pl.yticks([])
    pl.title('Matrix of pairwise distances')
    pl.tight_layout()

.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_001.png
   :alt: Source samples, Target samples, Matrix of pairwise distances
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_001.png
   :class: sphx-glr-single-img
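
The third panel of Fig 1 displays ``M``, computed with
``ot.dist(Xs, Xt, metric='sqeuclidean')``, whose entry ``M[i, j]`` is the
squared Euclidean distance between ``Xs[i]`` and ``Xt[j]``. The sketch below
writes the same quantity in plain NumPy using the expansion
``||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>``; the helper
``sqeuclidean_cost`` is hypothetical and only illustrates the formula.

.. code-block:: Python

    import numpy as np


    def sqeuclidean_cost(Xs, Xt):
        """Hypothetical helper: M[i, j] = ||Xs[i] - Xt[j]||^2 for 2D arrays."""
        # Expand ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
        sq_s = (Xs ** 2).sum(axis=1)[:, None]
        sq_t = (Xt ** 2).sum(axis=1)[None, :]
        M = sq_s + sq_t - 2.0 * Xs @ Xt.T
        # Clip tiny negative values caused by floating-point cancellation
        return np.maximum(M, 0.0)


    Xs_toy = np.array([[0.0, 0.0], [1.0, 2.0]])
    Xt_toy = np.array([[3.0, 4.0], [0.0, 1.0], [1.0, 1.0]])
    M_toy = sqeuclidean_cost(Xs_toy, Xt_toy)
    # M_toy holds [[25, 1, 2], [8, 2, 1]]

The expansion avoids building an ``(ns, nt, d)`` intermediate array, at the
price of small rounding errors, which is why the result is clipped at zero.
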
.. GENERATED FROM PYTHON SOURCE LINES 91-93

Fig 2: plot optimal couplings for the different methods
-------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 93-139

.. code-block:: Python

    pl.figure(2, figsize=(10, 6))

    # Top row: each coupling matrix as an image (rows: source, columns: target)
    pl.subplot(2, 3, 1)
    pl.imshow(ot_emd.coupling_, interpolation='nearest')
    pl.xticks([])
    pl.yticks([])
    pl.title('Optimal coupling\nEMDTransport')

    pl.subplot(2, 3, 2)
    pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')
    pl.xticks([])
    pl.yticks([])
    pl.title('Optimal coupling\nSinkhornTransport')

    pl.subplot(2, 3, 3)
    pl.imshow(ot_lpl1.coupling_, interpolation='nearest')
    pl.xticks([])
    pl.yticks([])
    pl.title('Optimal coupling\nSinkhornLpl1Transport')

    # Bottom row: segments between source and target samples, drawn for the
    # largest coupling coefficients
    pl.subplot(2, 3, 4)
    ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
    pl.xticks([])
    pl.yticks([])
    pl.title('Main coupling coefficients\nEMDTransport')

    pl.subplot(2, 3, 5)
    ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
    pl.xticks([])
    pl.yticks([])
    pl.title('Main coupling coefficients\nSinkhornTransport')

    pl.subplot(2, 3, 6)
    ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
    pl.xticks([])
    pl.yticks([])
    pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
    pl.tight_layout()
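
The bottom row of Fig 2 draws a segment between ``Xs[i]`` and ``Xt[j]`` for
the largest coupling coefficients. Our understanding is that
``ot.plot.plot2D_samples_mat`` keeps the pairs whose coefficient, divided by
the maximum of the coupling, exceeds a small threshold (its ``thr`` argument),
and uses that ratio as the line opacity. The sketch below reproduces only this
selection step, without plotting; ``main_coupling_pairs`` is hypothetical and
its default threshold is an assumption.

.. code-block:: Python

    import numpy as np


    def main_coupling_pairs(G, thr=1e-8):
        """Hypothetical helper: (i, j, alpha) for entries with G[i, j] / G.max() > thr.

        alpha is the relative weight G[i, j] / G.max(), usable as a line opacity.
        """
        rel = G / G.max()
        rows, cols = np.nonzero(rel > thr)
        return [(int(i), int(j), float(rel[i, j])) for i, j in zip(rows, cols)]


    # Toy couplings: a sparse, EMD-like one and a dense, Sinkhorn-like one
    G_sparse = np.array([[0.5, 0.0], [0.0, 0.5]])
    G_dense = np.array([[0.4, 0.1], [0.1, 0.4]])
    print(main_coupling_pairs(G_sparse))  # two pairs, both fully opaque
    print(main_coupling_pairs(G_dense))   # four pairs, off-diagonal ones fainter

Since ``K = exp(-M / reg)`` is strictly positive, every entry of an entropic
coupling is positive in exact arithmetic, so many more faint pairs pass the
threshold for the Sinkhorn-based methods than for the sparse EMD coupling.
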
.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_002.png
   :alt: Optimal coupling EMDTransport, Optimal coupling SinkhornTransport, Optimal coupling SinkhornLpl1Transport, Main coupling coefficients EMDTransport, Main coupling coefficients SinkhornTransport, Main coupling coefficients SinkhornLpl1Transport
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 140-142

Fig 3: plot transported samples
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 142-175

.. code-block:: Python

    # display transported samples
    pl.figure(3, figsize=(10, 4))
    pl.subplot(1, 3, 1)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
               label='Target samples', alpha=0.5)
    pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
               marker='+', label='Transp samples', s=30)
    pl.title('Transported samples\nEMDTransport')
    pl.legend(loc=0)
    pl.xticks([])
    pl.yticks([])

    pl.subplot(1, 3, 2)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
               label='Target samples', alpha=0.5)
    pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
               marker='+', label='Transp samples', s=30)
    pl.title('Transported samples\nSinkhornTransport')
    pl.xticks([])
    pl.yticks([])

    pl.subplot(1, 3, 3)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
               label='Target samples', alpha=0.5)
    pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
               marker='+', label='Transp samples', s=30)
    pl.title('Transported samples\nSinkhornLpl1Transport')
    pl.xticks([])
    pl.yticks([])

    pl.tight_layout()
    pl.show()

.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_003.png
   :alt: Transported samples EMDTransport, Transported samples SinkhornTransport, Transported samples SinkhornLpl1Transport
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 7.388 seconds)
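
The transported samples in Fig 3 come from ``transform(Xs=Xs)``. As we
understand POT's transport classes, for the same samples that were passed to
``fit`` this is a barycentric mapping: each source point is replaced by the
average of the target points weighted by its row of the coupling,
``sum_j G[i, j] * Xt[j] / sum_j G[i, j]``. A minimal NumPy sketch, with the
hypothetical helper ``barycentric_map``:

.. code-block:: Python

    import numpy as np


    def barycentric_map(G, Xt):
        """Hypothetical helper: map source i to sum_j G[i, j] Xt[j] / sum_j G[i, j]."""
        row_mass = G.sum(axis=1, keepdims=True)
        # Rows with zero mass would divide by zero; give them zero weights
        # (so they map to the origin) instead of NaN
        with np.errstate(divide='ignore', invalid='ignore'):
            weights = np.where(row_mass > 0, G / row_mass, 0.0)
        return weights @ Xt


    Xt_toy = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]])

    # EMD-like coupling: a scaled permutation, so each source point
    # lands exactly on one target point
    G_perm = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) / 3
    print(barycentric_map(G_perm, Xt_toy))  # rows (2, 0), (0, 2), (0, 0)

    # Sinkhorn-like row spread evenly over all targets: the point maps to their mean
    G_spread = np.full((1, 3), 1 / 9)
    print(barycentric_map(G_spread, Xt_toy))  # approximately (2/3, 2/3)

With equal sample sizes and uniform weights, an exact EMD solution is typically
a scaled permutation matrix, so each source point lands on a single target
point, while an entropic coupling spreads each row over several targets and
averages them. The sketch does not cover new samples that were not seen during
``fit``, which POT handles with a separate out-of-sample rule.
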