.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/domain-adaptation/plot_otda_d2.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_domain-adaptation_plot_otda_d2.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_domain-adaptation_plot_otda_d2.py:


===================================================
OT for domain adaptation on empirical distributions
===================================================

.. note::
    Example added in release: 0.1.9.

This example introduces domain adaptation in a 2D setting. It makes the
domain adaptation problem explicit and introduces some optimal transport
approaches to solve it.

Quantities such as optimal couplings, the largest coupling coefficients and
transported samples are plotted to give a visual understanding of what the
transport methods are doing.

.. GENERATED FROM PYTHON SOURCE LINES 18-30

.. code-block:: Python


    # Authors: Remi Flamary
    #          Stanislas Chambon
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 2

    import matplotlib.pylab as pl
    import ot
    import ot.plot


.. GENERATED FROM PYTHON SOURCE LINES 31-33

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 33-44

.. code-block:: Python


    n_samples_source = 150
    n_samples_target = 150

    Xs, ys = ot.datasets.make_data_classif("3gauss", n_samples_source)
    Xt, yt = ot.datasets.make_data_classif("3gauss2", n_samples_target)

    # Cost matrix
    M = ot.dist(Xs, Xt, metric="sqeuclidean")


.. GENERATED FROM PYTHON SOURCE LINES 45-47

Instantiate the different transport algorithms and fit them
-----------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 47-66

.. code-block:: Python


    # EMD Transport
    ot_emd = ot.da.EMDTransport()
    ot_emd.fit(Xs=Xs, Xt=Xt)

    # Sinkhorn Transport
    ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
    ot_sinkhorn.fit(Xs=Xs, Xt=Xt)

    # Sinkhorn Transport with Group lasso regularization
    ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
    ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)

    # transport source samples onto target samples
    transp_Xs_emd = ot_emd.transform(Xs=Xs)
    transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
    transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/ot/bregman/_sinkhorn.py:902: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
      warnings.warn(
    /home/circleci/project/ot/bregman/_sinkhorn.py:666: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
      warnings.warn(


.. GENERATED FROM PYTHON SOURCE LINES 67-69

Fig 1: plot source and target samples + matrix of pairwise distances
---------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 69-93

.. code-block:: Python


    pl.figure(1, figsize=(10, 10))
    pl.subplot(2, 2, 1)
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title("Source samples")

    pl.subplot(2, 2, 2)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title("Target samples")

    pl.subplot(2, 2, 3)
    pl.imshow(M, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title("Matrix of pairwise distances")
    pl.tight_layout()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_001.png
   :alt: Source samples, Target samples, Matrix of pairwise distances
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 94-96

Fig 2: plot optimal couplings for the different methods
---------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 96-142

.. code-block:: Python


    pl.figure(2, figsize=(10, 6))

    pl.subplot(2, 3, 1)
    pl.imshow(ot_emd.coupling_, interpolation="nearest", cmap="gray_r")
    pl.xticks([])
    pl.yticks([])
    pl.title("Optimal coupling\nEMDTransport")

    pl.subplot(2, 3, 2)
    pl.imshow(ot_sinkhorn.coupling_, interpolation="nearest", cmap="gray_r")
    pl.xticks([])
    pl.yticks([])
    pl.title("Optimal coupling\nSinkhornTransport")

    pl.subplot(2, 3, 3)
    pl.imshow(ot_lpl1.coupling_, interpolation="nearest", cmap="gray_r")
    pl.xticks([])
    pl.yticks([])
    pl.title("Optimal coupling\nSinkhornLpl1Transport")

    pl.subplot(2, 3, 4)
    ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[0.5, 0.5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
    pl.xticks([])
    pl.yticks([])
    pl.title("Main coupling coefficients\nEMDTransport")

    pl.subplot(2, 3, 5)
    ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[0.5, 0.5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
    pl.xticks([])
    pl.yticks([])
    pl.title("Main coupling coefficients\nSinkhornTransport")

    pl.subplot(2, 3, 6)
    ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[0.5, 0.5, 1])
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
    pl.xticks([])
    pl.yticks([])
    pl.title("Main coupling coefficients\nSinkhornLpl1Transport")
    pl.tight_layout()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_002.png
   :alt: Optimal coupling EMDTransport, Optimal coupling SinkhornTransport, Optimal coupling SinkhornLpl1Transport, Main coupling coefficients EMDTransport, Main coupling coefficients SinkhornTransport, Main coupling coefficients SinkhornLpl1Transport
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 143-145

Fig 3: plot transported samples
--------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 145-193

.. code-block:: Python


    # display transported samples
    pl.figure(3, figsize=(10, 4))
    pl.subplot(1, 3, 1)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
    pl.scatter(
        transp_Xs_emd[:, 0],
        transp_Xs_emd[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.title("Transported samples\nEmdTransport")
    pl.legend(loc=0)
    pl.xticks([])
    pl.yticks([])

    pl.subplot(1, 3, 2)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
    pl.scatter(
        transp_Xs_sinkhorn[:, 0],
        transp_Xs_sinkhorn[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.title("Transported samples\nSinkhornTransport")
    pl.xticks([])
    pl.yticks([])

    pl.subplot(1, 3, 3)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
    pl.scatter(
        transp_Xs_lpl1[:, 0],
        transp_Xs_lpl1[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.title("Transported samples\nSinkhornLpl1Transport")
    pl.xticks([])
    pl.yticks([])

    pl.tight_layout()
    pl.show()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_003.png
   :alt: Transported samples EmdTransport, Transported samples SinkhornTransport, Transported samples SinkhornLpl1Transport
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_d2_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.581 seconds)


.. _sphx_glr_download_auto_examples_domain-adaptation_plot_otda_d2.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_otda_d2.ipynb <plot_otda_d2.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_otda_d2.py <plot_otda_d2.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_otda_d2.zip <plot_otda_d2.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
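
The ``transform`` calls in this example move each source sample toward the
target samples it is coupled with. As a rough illustration only, the sketch
below assumes this in-sample mapping is a barycentric mapping: each source
sample is sent to the average of the target samples, weighted by its row of
the coupling matrix. The helper name ``barycentric_map`` and the toy arrays are
hypothetical and are not part of the example script; only NumPy is used, so
the snippet runs without POT.

.. code-block:: Python

    import numpy as np


    def barycentric_map(coupling, Xt):
        """Send each source sample to the coupling-weighted mean of Xt.

        Hypothetical helper written for illustration; it is not the POT API.
        """
        coupling = np.asarray(coupling, dtype=float)
        # Mass sent by each source sample (row sums of the coupling)
        row_mass = coupling.sum(axis=1, keepdims=True)
        # Normalize each row; rows with zero mass are left at zero
        weights = np.divide(
            coupling, row_mass, out=np.zeros_like(coupling), where=row_mass > 0
        )
        return weights @ np.asarray(Xt, dtype=float)


    # Toy data (assumed for illustration): 2 source and 3 target samples
    Xt_toy = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 4.0]])
    coupling_toy = np.array([[0.5, 0.0, 0.0], [0.0, 0.25, 0.25]])

    transp_toy = barycentric_map(coupling_toy, Xt_toy)
    print(transp_toy)

With this toy coupling, the first source sample is coupled to a single target
and lands exactly on it, while the second is split evenly between two targets
and lands at their midpoint. This is one way to read the figure of transported
samples: a sparse coupling such as the EMD one moves points onto or between a
few targets, whereas a smoother entropic coupling averages over more of them.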