.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/domain-adaptation/plot_otda_semi_supervised.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_domain-adaptation_plot_otda_semi_supervised.py:


============================================
OTDA unsupervised vs semi-supervised setting
============================================

.. note::
    Example added in release: 0.1.9.

This example introduces semi-supervised domain adaptation in a 2D setting.
It explains the semi-supervised domain adaptation problem and introduces
some optimal transport approaches to solve it.

Quantities such as the optimal couplings, the largest coupling coefficients
and the transported samples are represented to give a visual understanding
of what the transport methods are doing.

.. GENERATED FROM PYTHON SOURCE LINES 18-30

.. code-block:: Python


    # Authors: Remi Flamary
    #          Stanislas Chambon
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 3

    import matplotlib.pylab as pl

    import ot


.. GENERATED FROM PYTHON SOURCE LINES 31-33

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 33-41

.. code-block:: Python


    n_samples_source = 150
    n_samples_target = 150

    Xs, ys = ot.datasets.make_data_classif("3gauss", n_samples_source)
    Xt, yt = ot.datasets.make_data_classif("3gauss2", n_samples_target)


.. GENERATED FROM PYTHON SOURCE LINES 42-44

Transport source samples onto target samples
--------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 44-69

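
As a rough picture of what ``SinkhornTransport`` computes, the sketch below
runs plain Sinkhorn-Knopp iterations on a squared Euclidean cost with uniform
sample weights, then maps each source sample to the coupling-weighted mean of
the target samples (a barycentric mapping). It is a simplified NumPy
illustration under those assumptions, not POT's implementation: the helpers
``sinkhorn_coupling`` and ``barycentric_map`` are hypothetical, and the cost is
divided by its maximum only to keep this toy example numerically stable.

.. code-block:: Python

    import numpy as np


    def sinkhorn_coupling(a, b, M, reg, n_iter=1000):
        """Entropic OT coupling by Sinkhorn-Knopp scaling (hypothetical helper)."""
        K = np.exp(-M / reg)  # Gibbs kernel; an infinite cost gives a 0 entry
        u = np.ones_like(a)
        v = np.ones_like(b)
        for _ in range(n_iter):
            u = a / (K @ v)  # rescale rows towards the source weights a
            v = b / (K.T @ u)  # rescale columns to match the target weights b
        return u[:, None] * K * v[None, :]


    def barycentric_map(G, Xt):
        """Send each source sample to the G-weighted mean of the targets (hypothetical helper)."""
        return (G @ Xt) / G.sum(axis=1, keepdims=True)


    rng = np.random.default_rng(0)
    Xs_toy = rng.normal(size=(5, 2))  # toy source samples
    Xt_toy = rng.normal(loc=3.0, size=(6, 2))  # toy shifted target samples
    a = np.full(5, 1 / 5)  # uniform source weights
    b = np.full(6, 1 / 6)  # uniform target weights

    # squared Euclidean cost, divided by its maximum (assumption, see above)
    M = ((Xs_toy[:, None, :] - Xt_toy[None, :, :]) ** 2).sum(axis=-1)
    G = sinkhorn_coupling(a, b, M / M.max(), reg=1e-1)
    Xs_mapped = barycentric_map(G, Xt_toy)

Because each mapped point is a convex combination of target samples, it always
lies inside the convex hull of the target samples.
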
.. code-block:: Python


    # unsupervised domain adaptation
    ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
    ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
    transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)

    # semi-supervised domain adaptation
    ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
    ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
    transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)

    # Semi-supervised DA uses the labels available in the target domain to
    # modify the cost matrix of the OT problem: the cost of transporting a
    # source sample of class A onto a target sample of class B != A is set to
    # infinity, or to a very large value.

    # Note that in the present case all the target samples are labeled. In
    # practice, some target samples may have no label; the corresponding
    # elements of yt should then be set to -1.

    # Warning: -1 therefore cannot be used as a class label.


.. GENERATED FROM PYTHON SOURCE LINES 70-72

Fig 1 : plot source and target samples and the pairwise cost matrices
---------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 72-106

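
The label rule described in the comments of the previous code block can be
written down directly. The sketch below is a minimal NumPy version, not POT's
code: ``label_corrected_cost`` and its ``large_value`` argument are
hypothetical names, and ``-1`` marks an unlabeled sample as described in those
comments. A pair is penalized only when both samples are labeled and their
labels differ.

.. code-block:: Python

    import numpy as np


    def label_corrected_cost(M, ys, yt, large_value=np.inf):
        """Penalize pairs whose known labels disagree (hypothetical helper)."""
        ys = np.asarray(ys)
        yt = np.asarray(yt)
        # True where both the source and the target sample carry a label
        both_labeled = (ys[:, None] != -1) & (yt[None, :] != -1)
        # True where the two labels are different
        labels_differ = ys[:, None] != yt[None, :]
        M_corr = np.array(M, dtype=float)  # copy, so the input stays untouched
        M_corr[both_labeled & labels_differ] = large_value
        return M_corr


    M = np.ones((3, 4))  # toy cost matrix
    ys_toy = np.array([0, 1, 1])
    yt_toy = np.array([0, 1, -1, 0])  # the third target sample is unlabeled
    C = label_corrected_cost(M, ys_toy, yt_toy)

Here ``C`` is infinite at positions (0, 1), (1, 0), (1, 3), (2, 0) and (2, 3),
while the column of the unlabeled target sample keeps its original costs. With
an entropic solver, an infinite cost gives a zero entry in the kernel
``exp(-M / reg)``, so those pairs receive no mass in the coupling.
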
.. code-block:: Python


    pl.figure(1, figsize=(10, 10))

    pl.subplot(2, 2, 1)
    pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title("Source samples")

    pl.subplot(2, 2, 2)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title("Target samples")

    pl.subplot(2, 2, 3)
    pl.imshow(ot_sinkhorn_un.cost_, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title("Cost matrix - unsupervised DA")

    pl.subplot(2, 2, 4)
    pl.imshow(ot_sinkhorn_semi.cost_, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title("Cost matrix - semi-supervised DA")

    pl.tight_layout()

    # the optimal coupling in the semi-supervised DA case will have a shape
    # similar to that of the cost matrix (a block-diagonal matrix)


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_semi_supervised_001.png
   :alt: Source samples, Target samples, Cost matrix - unsupervised DA, Cost matrix - semi-supervised DA
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_semi_supervised_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 107-109

Fig 2 : plot optimal couplings for the different methods
--------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 109-127

.. code-block:: Python


    pl.figure(2, figsize=(8, 4))

    pl.subplot(1, 2, 1)
    pl.imshow(ot_sinkhorn_un.coupling_, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title("Optimal coupling\nUnsupervised DA")

    pl.subplot(1, 2, 2)
    pl.imshow(ot_sinkhorn_semi.coupling_, interpolation="nearest")
    pl.xticks([])
    pl.yticks([])
    pl.title("Optimal coupling\nSemi-supervised DA")

    pl.tight_layout()

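
The block structure of the semi-supervised coupling can also be checked
numerically by measuring the share of coupling mass that goes between samples
with different known labels. The sketch below is one way to do it;
``mismatched_mass`` is a hypothetical helper, not part of POT, and pairs
involving an unlabeled sample (label ``-1``) are ignored.

.. code-block:: Python

    import numpy as np


    def mismatched_mass(G, ys, yt):
        """Fraction of coupling mass between samples with different known labels."""
        ys = np.asarray(ys)
        yt = np.asarray(yt)
        # only pairs where both labels are known can be compared
        known = (ys[:, None] != -1) & (yt[None, :] != -1)
        mismatch = known & (ys[:, None] != yt[None, :])
        return G[mismatch].sum() / G.sum()


    # toy coupling: 0.1 of the total mass goes from class 0 to class 1
    G_toy = np.array([[0.4, 0.1], [0.0, 0.5]])
    frac = mismatched_mass(G_toy, [0, 1], [0, 1])

In this example, ``mismatched_mass(ot_sinkhorn_semi.coupling_, ys, yt)`` should
be at or near zero when mismatched pairs get an infinite cost, whereas the
unsupervised coupling ``ot_sinkhorn_un.coupling_`` is free to spread mass
across classes.
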
.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_semi_supervised_002.png
   :alt: Optimal coupling Unsupervised DA, Optimal coupling Semi-supervised DA
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_semi_supervised_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 128-130

Fig 3 : plot transported samples
--------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 130-164

.. code-block:: Python


    # display transported samples
    pl.figure(3, figsize=(8, 4))
    pl.subplot(1, 2, 1)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
    pl.scatter(
        transp_Xs_sinkhorn_un[:, 0],
        transp_Xs_sinkhorn_un[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.title("Transported samples\nUnsupervised DA")
    pl.legend(loc=0)
    pl.xticks([])
    pl.yticks([])

    pl.subplot(1, 2, 2)
    pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
    pl.scatter(
        transp_Xs_sinkhorn_semi[:, 0],
        transp_Xs_sinkhorn_semi[:, 1],
        c=ys,
        marker="+",
        label="Transp samples",
        s=30,
    )
    pl.title("Transported samples\nSemi-supervised DA")
    pl.xticks([])
    pl.yticks([])

    pl.tight_layout()
    pl.show()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_semi_supervised_003.png
   :alt: Transported samples Unsupervised DA, Transported samples Semi-supervised DA
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_semi_supervised_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.607 seconds)


.. _sphx_glr_download_auto_examples_domain-adaptation_plot_otda_semi_supervised.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_otda_semi_supervised.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_otda_semi_supervised.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_otda_semi_supervised.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_