Wasserstein Discriminant Analysis

This example illustrates the use of Wasserstein Discriminant Analysis (WDA), as proposed in [11], and compares it with Fisher Discriminant Analysis (FDA).

[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis.

# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl

from ot.dr import wda, fda

Generate data

n = 1000  # number of samples in each of the source and target datasets
nz = 0.2  # std of the Gaussian noise added to the 2-D circle coordinates


def _make_circles(n_samples, noise, n_classes=3):
    """Sample noisy points on ``n_classes`` concentric circles in 2-D.

    Samples are split into consecutive, (nearly) equal groups labelled
    ``1 .. n_classes``; a sample with label ``c`` lies on the circle of
    radius ``c`` before isotropic Gaussian noise of std ``noise`` is added.

    Draws exactly ``n_samples`` uniforms (angles) and then
    ``n_samples * 2`` normals (noise) from the global NumPy RNG, in that
    order, so seeded output is reproducible.

    Returns
    -------
    x : ndarray, shape (n_samples, 2)
        Sample coordinates.
    y : ndarray, shape (n_samples,)
        Float class labels in ``{1., ..., float(n_classes)}``.
    """
    angles = np.random.rand(n_samples) * 2 * np.pi
    y = np.floor(np.arange(n_samples) * 1.0 / n_samples * n_classes) + 1
    x = np.column_stack((np.cos(angles), np.sin(angles)))
    # scale each unit-circle point by its class label -> radius = label
    x = x * y[:, None] + noise * np.random.randn(n_samples, 2)
    return x, y


np.random.seed(1)

# generate circle datasets: source (training) first, then target (test)
xs, ys = _make_circles(n, nz)
xt, yt = _make_circles(n, nz)

nbnoise = 8  # number of extra pure-noise (non-discriminant) dimensions

# append the noise dimensions after the 2 informative ones
xs = np.hstack((xs, np.random.randn(n, nbnoise)))
xt = np.hstack((xt, np.random.randn(n, nbnoise)))

Plot data

pl.figure(1, figsize=(6.4, 3.5))

# Left: the two informative dimensions (one noisy circle per class).
# Plot the source set xs so the data match both the colours (ys) and the
# 'Source samples' legend label.
pl.subplot(1, 2, 1)
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')

# Right: two of the appended pure-noise dimensions.
pl.subplot(1, 2, 2)
pl.scatter(xs[:, 2], xs[:, 3], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Other dimensions')
pl.tight_layout()
Discriminant dimensions, Other dimensions

Compute Fisher Discriminant Analysis

# Baseline: learn a p-dimensional linear projection with Fisher
# Discriminant Analysis; `projfda` is later called on data to project it.
p = 2  # dimension of the projected space
Pfda, projfda = fda(xs, ys, p)

Compute Wasserstein Discriminant Analysis

# WDA hyper-parameters, passed straight through to `wda` below.
p = 2  # dimension of the projected space
reg = 1e0  # regularization strength (wda's `reg` argument)
k = 10  # wda's `k` argument
maxiter = 100  # iteration cap for the optimizer

# Random initial projection whose columns are rescaled to unit L2 norm.
P0 = np.random.randn(xs.shape[1], p)
column_norms = np.sqrt((P0 ** 2).sum(axis=0, keepdims=True))
P0 = P0 / column_norms

Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter, P0=P0)

Out:

 iter              cost val         grad. norm
    1   +8.3042776946697494e-01 5.65147154e-01
    2   +4.4401037686381029e-01 2.16760501e-01
    3   +4.2234351238819945e-01 1.30555049e-01
    4   +4.2169879996364401e-01 1.39115407e-01
    5   +4.1924746118060263e-01 1.25387848e-01
    6   +4.1177409528989911e-01 6.70993539e-02
    7   +4.0862213476139281e-01 3.52716830e-02
    8   +4.0747229322239997e-01 3.34923131e-02
    9   +4.0678766065263389e-01 2.74029183e-02
   10   +4.0621337155458270e-01 2.03651803e-02
   11   +4.0577080390746895e-01 2.59605592e-02
   12   +4.0543140912447923e-01 3.28883715e-02
   13   +4.0470236926304071e-01 1.47528039e-02
   14   +4.0445628469363082e-01 5.03183257e-02
   15   +4.0364189462420308e-01 3.31006521e-02
   16   +4.0303977558223164e-01 1.39885302e-02
   17   +4.0301476264735470e-01 2.17467792e-02
   18   +4.0292344437698596e-01 1.79960568e-02
   19   +4.0271888410932871e-01 6.94405752e-03
   20   +4.0183209912048512e-01 1.98305321e-02
   21   +3.9762292385877218e-01 1.03204023e-01
   22   +3.8223876599179624e-01 1.36080567e-01
   23   +3.0850475016688772e-01 1.92700755e-01
   24   +2.8013381841750584e-01 2.02121729e-01
   25   +2.3660197531968816e-01 8.83012243e-02
   26   +2.3506634015833044e-01 7.91650282e-02
   27   +2.3103093050518325e-01 2.46231369e-02
   28   +2.3062144241991714e-01 5.00497774e-03
   29   +2.3060877912413660e-01 2.53579804e-03
   30   +2.3060533093932806e-01 1.15685246e-03
   31   +2.3060517207638889e-01 1.05068242e-03
   32   +2.3060467214615721e-01 6.02549410e-04
   33   +2.3060442721731178e-01 1.44081829e-05
   34   +2.3060442714409929e-01 1.00254196e-05
   35   +2.3060442707640050e-01 8.02599057e-07
Terminated - min grad norm reached after 35 iterations, 10.54 seconds.

Plot 2D projections

# Project training (xs) and test (xt) data with both learned projections.
xsp = projfda(xs)
xtp = projfda(xt)

xspw = projwda(xs)
xtpw = projwda(xt)

pl.figure(2)

# One entry per panel: (subplot index, projected data, labels, title).
# Test panels are coloured by yt, the labels that belong to xt.
panels = [
    (1, xsp, ys, 'Projected training samples FDA'),
    (2, xtp, yt, 'Projected test samples FDA'),
    (3, xspw, ys, 'Projected training samples WDA'),
    (4, xtpw, yt, 'Projected test samples WDA'),
]
for idx, proj, labels, title in panels:
    pl.subplot(2, 2, idx)
    pl.scatter(proj[:, 0], proj[:, 1], c=labels, marker='+',
               label='Projected samples')
    pl.legend(loc=0)
    pl.title(title)
pl.tight_layout()

pl.show()
Projected training samples FDA, Projected test samples FDA, Projected training samples WDA, Projected test samples WDA

Total running time of the script: ( 0 minutes 11.093 seconds)

Gallery generated by Sphinx-Gallery