.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_WDA.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_others_plot_WDA.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_WDA.py:


=================================
Wasserstein Discriminant Analysis
=================================

This example illustrates the use of Wasserstein Discriminant Analysis (WDA),
as proposed in [11], and compares it with Fisher Discriminant Analysis (FDA)
on a toy dataset in which only two of the ten dimensions carry class
information.

[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
Wasserstein Discriminant Analysis.

.. GENERATED FROM PYTHON SOURCE LINES 14-27

.. code-block:: Python


    # Author: Remi Flamary
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 2

    import numpy as np
    import matplotlib.pyplot as pl

    from ot.dr import wda, fda


.. GENERATED FROM PYTHON SOURCE LINES 28-30

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 32-56

.. code-block:: Python


    n = 1000  # number of samples in the source and target datasets
    nz = 0.2  # standard deviation of the noise added to the rings

    np.random.seed(1)

    # source dataset: three concentric circles of radius 1, 2 and 3,
    # one circle per class
    t = np.random.rand(n) * 2 * np.pi
    ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
    xs = np.concatenate(
        (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
    xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)

    # target dataset, generated in the same way
    t = np.random.rand(n) * 2 * np.pi
    yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
    xt = np.concatenate(
        (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
    xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)

    # append uninformative Gaussian noise dimensions
    nbnoise = 8

    xs = np.hstack((xs, np.random.randn(n, nbnoise)))
    xt = np.hstack((xt, np.random.randn(n, nbnoise)))


.. GENERATED FROM PYTHON SOURCE LINES 57-59

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 61-74

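Before plotting, it helps to check what the generator produced. FDA looks for
directions that separate the class *means*, but all three rings are centered
on the origin, so their means nearly coincide and only the ring radii tell the
classes apart. The sketch below rebuilds a dataset shaped like ``xs`` (the
helper ``make_rings`` is hypothetical, and its random draws are not guaranteed
to match the example's exactly) and prints, for each class, the mean of the two
informative coordinates and the mean distance to the origin.

.. code-block:: Python

    import numpy as np


    def make_rings(n=1000, nz=0.2, nbnoise=8, seed=1):
        """Hypothetical helper: three noisy concentric rings plus noise dims."""
        rng = np.random.RandomState(seed)
        t = rng.rand(n) * 2 * np.pi
        y = np.floor(np.arange(n) * 3.0 / n) + 1  # labels 1, 2, 3 in contiguous blocks
        x = np.column_stack((np.cos(t), np.sin(t))) * y[:, None]  # ring radius = label
        x = x + nz * rng.randn(n, 2)  # isotropic noise around each ring
        x = np.hstack((x, rng.randn(n, nbnoise)))  # uninformative dimensions
        return x, y


    x, y = make_rings()
    r = np.linalg.norm(x[:, :2], axis=1)  # distance to the origin in dims 0-1
    for c in np.unique(y):
        m = x[y == c, :2].mean(axis=0)  # class mean in the informative dims
        print(f"class {int(c)}: mean {np.round(m, 2)}, mean radius {r[y == c].mean():.2f}")
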

.. code-block:: Python


    pl.figure(1, figsize=(6.4, 3.5))

    pl.subplot(1, 2, 1)
    pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.legend(loc=0)
    pl.title('Discriminant dimensions')

    pl.subplot(1, 2, 2)
    pl.scatter(xs[:, 2], xs[:, 3], c=ys, marker='+', label='Source samples')
    pl.legend(loc=0)
    pl.title('Other dimensions')
    pl.tight_layout()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_WDA_001.png
   :alt: Discriminant dimensions, Other dimensions
   :srcset: /auto_examples/others/images/sphx_glr_plot_WDA_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 75-77

Compute Fisher Discriminant Analysis
------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 79-83

.. code-block:: Python


    p = 2  # dimension of the projected space

    Pfda, projfda = fda(xs, ys, p)


.. GENERATED FROM PYTHON SOURCE LINES 84-86

Compute Wasserstein Discriminant Analysis
-----------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 88-100

.. code-block:: Python


    p = 2  # dimension of the projected space
    reg = 1e0  # entropic regularization of the Sinkhorn solver
    k = 10  # number of Sinkhorn iterations
    maxiter = 100  # maximum number of optimizer iterations

    # random initial projection with unit-norm columns
    P0 = np.random.randn(xs.shape[1], p)
    P0 /= np.sqrt(np.sum(P0**2, 0, keepdims=True))

    Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter, P0=P0)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Optimizing...
    Iteration    Cost                       Gradient norm
    ---------    -----------------------    --------------
       1         +8.3042776946697494e-01    5.65147154e-01
       2         +4.4401037686381040e-01    2.16760501e-01
       3         +4.2234351238819928e-01    1.30555049e-01
       4         +4.2169879996364462e-01    1.39115407e-01
       5         +4.1924746118060602e-01    1.25387848e-01
       6         +4.1177409528990749e-01    6.70993539e-02
       7         +4.0862213476139048e-01    3.52716830e-02
       8         +4.0747229322240269e-01    3.34923131e-02
       9         +4.0678766065261684e-01    2.74029183e-02
      10         +4.0621337155459647e-01    2.03651803e-02
      11         +4.0577080390746939e-01    2.59605592e-02
      12         +4.0543140912472148e-01    3.28883715e-02
      13         +4.0470236926310577e-01    1.47528039e-02
      14         +4.0445628467498224e-01    5.03183254e-02
      15         +4.0364189455866245e-01    3.31006504e-02
      16         +4.0303977563823823e-01    1.39885352e-02
      17         +4.0301476238242911e-01    2.17467624e-02
      18         +4.0292344306414324e-01    1.79959907e-02
      19         +4.0271888325518124e-01    6.94408237e-03
      20         +4.0183214741002155e-01    1.98322994e-02
      21         +3.9762636217090053e-01    1.03196875e-01
      22         +3.8225627240876070e-01    1.36012863e-01
      23         +3.0855506616050116e-01    1.92702943e-01
      24         +2.8001027160864295e-01    2.01920255e-01
      25         +2.3687486090807947e-01    9.01780640e-02
      26         +2.3431203993360381e-01    7.23716793e-02
      27         +2.3118645266923005e-01    2.90753137e-02
      28         +2.3067593392325469e-01    1.02767925e-02
      29         +2.3064856262240019e-01    8.07925279e-03
      30         +2.3060699763593800e-01    1.95215754e-03
      31         +2.3060442760754873e-01    2.77368118e-05
      32         +2.3060442709529139e-01    5.34108449e-06
      33         +2.3060442708435561e-01    3.52599061e-06
      34         +2.3060442707674844e-01    1.07742368e-06
      35         +2.3060442707600512e-01    2.36125504e-07
    Terminated - min grad norm reached after 35 iterations, 8.68 seconds.




.. GENERATED FROM PYTHON SOURCE LINES 101-103

Plot 2D projections
-------------------

.. GENERATED FROM PYTHON SOURCE LINES 105-136

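Both ``fda`` and ``wda`` return a ``(d, p)`` projection matrix together with a
callable that applies it to new samples, which is why the test samples ``xt``
can be projected with a model fitted on ``xs`` only. The sketch below applies
such a linear map by hand, assuming the callable amounts to an optional
centering followed by the matrix product ``X @ P``. The matrix ``P_good`` is a
hypothetical, idealized projection that simply keeps the two informative
dimensions (it is not learned by ``fda`` or ``wda``), while ``P_rand`` is drawn
and normalized column-wise in the same way as ``P0`` above. The helpers
``project`` and ``class_radii`` are illustrative only.

.. code-block:: Python

    import numpy as np

    rng = np.random.RandomState(0)
    n, d, p = 600, 10, 2

    # Toy data: noisy rings (radii 1, 2, 3) in dims 0-1, Gaussian noise in dims 2-9
    y = np.repeat([1.0, 2.0, 3.0], n // 3)
    t = rng.rand(n) * 2 * np.pi
    X = rng.randn(n, d)
    X[:, :2] = np.column_stack((np.cos(t), np.sin(t))) * y[:, None] + 0.2 * rng.randn(n, 2)


    def project(X, P, center=None):
        """Hypothetical helper: map (n, d) samples to (n, p) with a (d, p) matrix."""
        if center is not None:
            X = X - center  # optional centering before the linear map
        return X @ P


    def class_radii(Z, y):
        """Mean distance to the origin of each class in the projected space."""
        r = np.linalg.norm(Z, axis=1)
        return np.array([r[y == c].mean() for c in np.unique(y)])


    # Idealized projection: keep only the two informative dimensions
    P_good = np.zeros((d, p))
    P_good[0, 0] = P_good[1, 1] = 1.0

    # Random projection with unit-norm columns, built like P0 above
    P_rand = rng.randn(d, p)
    P_rand /= np.sqrt(np.sum(P_rand**2, 0, keepdims=True))

    print(class_radii(project(X, P_good), y))  # close to [1, 2, 3]
    print(class_radii(project(X, P_rand), y))  # usually closer together: most column weight is on noise dims
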

.. code-block:: Python


    # project source (training) and target (test) samples with both methods
    xsp = projfda(xs)
    xtp = projfda(xt)

    xspw = projwda(xs)
    xtpw = projwda(xt)

    pl.figure(2)

    pl.subplot(2, 2, 1)
    pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected training samples FDA')

    pl.subplot(2, 2, 2)
    pl.scatter(xtp[:, 0], xtp[:, 1], c=yt, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected test samples FDA')

    pl.subplot(2, 2, 3)
    pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected training samples WDA')

    pl.subplot(2, 2, 4)
    pl.scatter(xtpw[:, 0], xtpw[:, 1], c=yt, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected test samples WDA')
    pl.tight_layout()

    pl.show()


.. image-sg:: /auto_examples/others/images/sphx_glr_plot_WDA_002.png
   :alt: Projected training samples FDA, Projected test samples FDA, Projected training samples WDA, Projected test samples WDA
   :srcset: /auto_examples/others/images/sphx_glr_plot_WDA_002.png
   :class: sphx-glr-single-img



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.425 seconds)


.. _sphx_glr_download_auto_examples_others_plot_WDA.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_WDA.ipynb <plot_WDA.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_WDA.py <plot_WDA.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery