.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/domain-adaptation/plot_otda_linear_mapping.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_domain-adaptation_plot_otda_linear_mapping.py:


============================
Linear OT mapping estimation
============================

.. GENERATED FROM PYTHON SOURCE LINES 10-17

.. code-block:: Python

    # Author: Remi Flamary
    #
    # License: MIT License
    # sphinx_gallery_thumbnail_number = 2


.. GENERATED FROM PYTHON SOURCE LINES 18-25

.. code-block:: Python

    import os
    from pathlib import Path

    import numpy as np
    from matplotlib import pyplot as plt

    import ot


.. GENERATED FROM PYTHON SOURCE LINES 26-28

Generate data
-------------

.. GENERATED FROM PYTHON SOURCE LINES 28-53

.. code-block:: Python

    n = 1000
    d = 2
    sigma = .1

    rng = np.random.RandomState(42)

    # source samples: noisy points on a unit circle, the first half shifted up by 2
    angles = rng.rand(n, 1) * 2 * np.pi
    xs = np.concatenate((np.sin(angles), np.cos(angles)), axis=1) + sigma * rng.randn(n, 2)
    xs[:n // 2, 1] += 2

    # target samples: same construction, then the affine map x -> x A + b
    anglet = rng.rand(n, 1) * 2 * np.pi
    xt = np.concatenate((np.sin(anglet), np.cos(anglet)), axis=1) + sigma * rng.randn(n, 2)
    xt[:n // 2, 1] += 2

    A = np.array([[1.5, .7], [.7, 1.5]])
    b = np.array([[4, 2]])

    xt = xt.dot(A) + b


.. GENERATED FROM PYTHON SOURCE LINES 54-56

Plot data
---------

.. GENERATED FROM PYTHON SOURCE LINES 56-64

.. code-block:: Python

    plt.figure(1, (5, 5))
    plt.plot(xs[:, 0], xs[:, 1], '+')
    plt.plot(xt[:, 0], xt[:, 1], 'o')
    plt.legend(('Source', 'Target'))
    plt.title('Source and target distributions')
    plt.show()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_linear_mapping_001.png
   :alt: Source and target distributions
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_linear_mapping_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 65-67

Estimate linear mapping and transport
-------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 67-79

.. code-block:: Python

    # Gaussian (linear) Monge mapping estimation
    Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)

    # apply the estimated affine map to the source samples (row-vector convention)
    xst = xs.dot(Ae) + be

    # Gaussian (linear) GW mapping estimation
    Agw, bgw = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt)

    xstgw = xs.dot(Agw) + bgw


.. GENERATED FROM PYTHON SOURCE LINES 80-82

Plot transported samples
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 82-99

.. code-block:: Python

    plt.figure(2, (10, 5))
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.plot(xs[:, 0], xs[:, 1], '+')
    plt.plot(xt[:, 0], xt[:, 1], 'o')
    plt.plot(xst[:, 0], xst[:, 1], '+')
    plt.legend(('Source', 'Target', 'Transp. Monge'), loc=0)
    plt.title('Transported samples with Monge')
    plt.subplot(1, 2, 2)
    plt.plot(xs[:, 0], xs[:, 1], '+')
    plt.plot(xt[:, 0], xt[:, 1], 'o')
    plt.plot(xstgw[:, 0], xstgw[:, 1], '+')
    plt.legend(('Source', 'Target', 'Transp. GW'), loc=0)
    plt.title('Transported samples with Gaussian GW')
    plt.show()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_linear_mapping_002.png
   :alt: Transported samples with Monge, Transported samples with Gaussian GW
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_linear_mapping_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 100-102

Load image data
---------------

.. GENERATED FROM PYTHON SOURCE LINES 102-129

.. code-block:: Python

    def im2mat(img):
        """Convert an image to a matrix (one pixel per row)."""
        return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))


    def mat2im(X, shape):
        """Convert a matrix back to an image."""
        return X.reshape(shape)


    def minmax(img):
        """Clip pixel values to [0, 1] for display."""
        return np.clip(img, 0, 1)


    # Loading images
    # The string '__file__' (not the __file__ variable) is resolved against the
    # current working directory, so data_path is <cwd>/../../data
    this_file = os.path.realpath('__file__')
    data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')

    I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
    I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256

    # one sample per pixel, one feature per color channel
    X1 = im2mat(I1)
    X2 = im2mat(I2)


.. GENERATED FROM PYTHON SOURCE LINES 130-132

Estimate mapping and adapt
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 132-156

.. code-block:: Python

    # Monge mapping
    mapping = ot.da.LinearTransport()
    mapping.fit(Xs=X1, Xt=X2)

    # transform maps source pixels to the target, inverse_transform maps back
    xst = mapping.transform(Xs=X1)
    xts = mapping.inverse_transform(Xt=X2)

    # reshape to images and clip values to [0, 1] for display
    I1t = minmax(mat2im(xst, I1.shape))
    I2t = minmax(mat2im(xts, I2.shape))

    # Gaussian GW mapping
    mapping = ot.da.LinearGWTransport()
    mapping.fit(Xs=X1, Xt=X2)

    xstgw = mapping.transform(Xs=X1)
    xtsgw = mapping.inverse_transform(Xt=X2)

    I1tgw = minmax(mat2im(xstgw, I1.shape))
    I2tgw = minmax(mat2im(xtsgw, I2.shape))


.. GENERATED FROM PYTHON SOURCE LINES 160-162

Plot transformed images
-----------------------

.. GENERATED FROM PYTHON SOURCE LINES 162-194

.. code-block:: Python

    plt.figure(3, figsize=(14, 7))

    plt.subplot(2, 3, 1)
    plt.imshow(I1)
    plt.axis('off')
    plt.title('Im. 1')

    plt.subplot(2, 3, 4)
    plt.imshow(I2)
    plt.axis('off')
    plt.title('Im. 2')

    plt.subplot(2, 3, 2)
    plt.imshow(I1t)
    plt.axis('off')
    plt.title('Monge mapping Im. 1')

    plt.subplot(2, 3, 5)
    plt.imshow(I2t)
    plt.axis('off')
    plt.title('Inverse Monge mapping Im. 2')

    plt.subplot(2, 3, 3)
    plt.imshow(I1tgw)
    plt.axis('off')
    plt.title('Gaussian GW mapping Im. 1')

    plt.subplot(2, 3, 6)
    plt.imshow(I2tgw)
    plt.axis('off')
    plt.title('Inverse Gaussian GW mapping Im. 2')
    plt.show()


.. image-sg:: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_linear_mapping_003.png
   :alt: Im. 1, Im. 2, Monge mapping Im. 1, Inverse Monge mapping Im. 2, Gaussian GW mapping Im. 1, Inverse Gaussian GW mapping Im. 2
   :srcset: /auto_examples/domain-adaptation/images/sphx_glr_plot_otda_linear_mapping_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.514 seconds)
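
Closed-form Gaussian Monge mapping
----------------------------------

The Gaussian (linear) Monge mapping estimated in this example has a closed form. If Gaussians :math:`\mathcal{N}(\mu_s, \Sigma_s)` and :math:`\mathcal{N}(\mu_t, \Sigma_t)` are fitted to the source and target samples, the optimal transport map for the squared Euclidean cost is :math:`T(x) = \mu_t + A (x - \mu_s)`, with :math:`A = \Sigma_s^{-1/2} (\Sigma_s^{1/2} \Sigma_t \Sigma_s^{1/2})^{1/2} \Sigma_s^{-1/2}`. The minimal NumPy-only sketch below applies this formula to the toy data. It assumes unregularized sample covariances from ``np.cov``. ``ot.gaussian.empirical_bures_wasserstein_mapping`` may differ in details such as regularization or covariance normalization, so its estimates can differ slightly. The helpers ``psd_sqrt_and_inv_sqrt`` and ``gaussian_monge_mapping`` are hypothetical names used only for illustration.

.. code-block:: Python

    import numpy as np


    def psd_sqrt_and_inv_sqrt(S):
        """Return S^(1/2) and S^(-1/2) for a symmetric positive definite matrix S."""
        S = (S + S.T) / 2  # enforce exact symmetry before the eigendecomposition
        w, V = np.linalg.eigh(S)
        return (V * np.sqrt(w)) @ V.T, (V / np.sqrt(w)) @ V.T


    def gaussian_monge_mapping(xs, xt):
        """Hypothetical helper: closed-form Monge map between Gaussian fits of xs and xt.

        Returns (A, b) such that xs @ A + b transports the source samples
        (row-vector convention, as in the example).
        """
        mus, mut = xs.mean(axis=0), xt.mean(axis=0)
        Cs = np.cov(xs, rowvar=False)  # assumption: unregularized sample covariances
        Ct = np.cov(xt, rowvar=False)
        Cs12, Cs_12 = psd_sqrt_and_inv_sqrt(Cs)
        M12, _ = psd_sqrt_and_inv_sqrt(Cs12 @ Ct @ Cs12)
        A = Cs_12 @ M12 @ Cs_12  # symmetric, so x @ A is the same map as A @ x
        b = mut - mus @ A
        return A, b[None, :]


    # Toy data built as in "Generate data"
    n, sigma = 1000, .1
    rng = np.random.RandomState(42)
    angles = rng.rand(n, 1) * 2 * np.pi
    xs = np.concatenate((np.sin(angles), np.cos(angles)), axis=1) + sigma * rng.randn(n, 2)
    xs[:n // 2, 1] += 2
    anglet = rng.rand(n, 1) * 2 * np.pi
    xt = np.concatenate((np.sin(anglet), np.cos(anglet)), axis=1) + sigma * rng.randn(n, 2)
    xt[:n // 2, 1] += 2
    xt = xt @ np.array([[1.5, .7], [.7, 1.5]]) + np.array([[4, 2]])

    A, b = gaussian_monge_mapping(xs, xt)
    xst = xs @ A + b

    # The transported samples reproduce the target's empirical mean and covariance
    print(np.allclose(xst.mean(axis=0), xt.mean(axis=0)))  # True
    print(np.allclose(np.cov(xst, rowvar=False), np.cov(xt, rowvar=False)))  # True

The map matches only the first and second moments. It is therefore the exact optimal transport map only when both distributions are Gaussian. Otherwise, it is the optimal map between their Gaussian approximations.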
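
Linear mapping of pixel colors
------------------------------

In the image part of this example, ``im2mat`` turns an H x W x 3 image into an (H*W) x 3 matrix, so each pixel's color becomes one 3-D sample. Assuming the fitted transport acts on each row as ``x @ A + b`` (as ``xs.dot(Ae) + be`` does in the toy example), every pixel then receives the same 3 x 3 color transform and offset. The sketch below checks that flattening, mapping and reshaping gives the same result as applying ``img @ A + b`` directly, because NumPy's matmul broadcasts over the leading axes. The random stand-in image, the matrix ``A`` and the offset ``b`` are hypothetical values, not estimated from the ocean images.

.. code-block:: Python

    import numpy as np


    def im2mat(img):
        """Convert an image to a matrix (one pixel per row)."""
        return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))


    def mat2im(X, shape):
        """Convert a matrix back to an image."""
        return X.reshape(shape)


    rng = np.random.RandomState(0)
    img = rng.rand(4, 5, 3)  # hypothetical stand-in for an RGB image in [0, 1)

    # Hypothetical affine color map (not fitted to any data)
    A = np.array([[0.9, 0.1, 0.0],
                  [0.0, 0.8, 0.1],
                  [0.1, 0.0, 0.7]])
    b = np.array([[0.05, 0.0, -0.02]])

    # Route used in the example: flatten, map every pixel, reshape back
    mapped_via_matrix = mat2im(im2mat(img) @ A + b, img.shape)

    # Equivalent: matmul broadcasts over the (row, column) axes of the image
    mapped_direct = img @ A + b

    print(np.allclose(mapped_via_matrix, mapped_direct))  # True
    print(np.clip(mapped_direct, 0, 1).shape)  # (4, 5, 3)

Because the map ignores pixel positions, it can only recolor an image globally. Spatial structure is left untouched, which is why the transformed images in the example keep their content and change only their color palette.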