ot.utils

Various useful functions

Functions

ot.utils.check_params(**kwargs)[source]

Check whether some parameters are missing

ot.utils.check_random_state(seed)[source]

Turn seed into a np.random.RandomState instance

Parameters

seed (None | int | instance of RandomState) – If seed is None, return the RandomState singleton used by np.random. If seed is an int, return a new RandomState instance seeded with seed. If seed is already a RandomState instance, return it. Otherwise raise ValueError.
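The seed rules above can be sketched in a few lines of numpy. This is only an illustration of the documented behaviour, not POT's implementation; the name check_random_state_sketch is hypothetical, and the use of the private np.random.mtrand._rand attribute to reach the global singleton is an assumption.

```python
import numbers

import numpy as np


def check_random_state_sketch(seed):
    """Hypothetical restatement of the documented seed rules."""
    if seed is None:
        # Assumption: the global RandomState singleton behind np.random.*
        return np.random.mtrand._rand
    if isinstance(seed, numbers.Integral):
        # An int gives a fresh, reproducible generator
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        # An existing generator is returned unchanged
        return seed
    raise ValueError(f"{seed!r} cannot be used to seed a RandomState instance")


rng_a = check_random_state_sketch(42)
rng_b = check_random_state_sketch(42)
same_draws = np.allclose(rng_a.rand(3), rng_b.rand(3))  # same seed, same stream
passthrough = check_random_state_sketch(rng_a) is rng_a
```

Passing the same integer twice yields two independent generators that produce the same sequence, which is what makes seeded experiments reproducible.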

ot.utils.clean_zeros(a, b, M)[source]

Remove all components with zero weights in \(\mathbf{a}\) and \(\mathbf{b}\)
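A minimal numpy sketch of this clean-up, assuming that the rows of M are indexed by \(\mathbf{a}\), the columns by \(\mathbf{b}\), and that the cleaned triple is returned in the order (a, b, M). The name clean_zeros_sketch is hypothetical and this is not POT's code.

```python
import numpy as np


def clean_zeros_sketch(a, b, M):
    """Drop zero-weight entries of a and b and the matching rows/columns of M."""
    a, b, M = np.asarray(a), np.asarray(b), np.asarray(M)
    keep_a = a > 0  # rows of M to keep
    keep_b = b > 0  # columns of M to keep
    return a[keep_a], b[keep_b], M[keep_a][:, keep_b]


a = np.array([0.5, 0.0, 0.5])
b = np.array([0.0, 1.0])
M = np.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
a2, b2, M2 = clean_zeros_sketch(a, b, M)
```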

ot.utils.cost_normalization(C, norm=None)[source]

Apply normalization to the loss matrix

Parameters
  • C (ndarray, shape (n1, n2)) – The cost matrix to normalize.

  • norm (str) – Type of normalization, one of ‘median’, ‘max’, ‘log’, ‘loglog’. Any other value leaves the matrix unnormalized.

Returns

C – The input cost matrix normalized according to given norm.

Return type

ndarray, shape (n1, n2)
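The four modes can be sketched as follows. The exact formulas are assumptions, since the text above only names the modes: here ‘median’ and ‘max’ divide by the median and the maximum of C, ‘log’ applies log(1 + C), and ‘loglog’ applies it twice. The name cost_normalization_sketch is hypothetical.

```python
import numpy as np


def cost_normalization_sketch(C, norm=None):
    """Illustrative normalization of a cost matrix (formulas are assumptions)."""
    C = np.asarray(C, dtype=float)
    if norm == "median":
        return C / np.median(C)
    if norm == "max":
        return C / np.max(C)
    if norm == "log":
        return np.log1p(C)
    if norm == "loglog":
        return np.log1p(np.log1p(C))
    # Any other value: return the matrix unchanged, as documented
    return C


C = np.array([[0.0, 1.0], [2.0, 4.0]])
C_max = cost_normalization_sketch(C, "max")
C_med = cost_normalization_sketch(C, "median")
```

Normalizing by the maximum or the median makes regularization strengths comparable across datasets whose costs live on different scales.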

ot.utils.dist(x1, x2=None, metric='sqeuclidean', p=2, w=None)[source]

Compute distance between samples in \(\mathbf{x_1}\) and \(\mathbf{x_2}\)

Note

This function is backend-compatible and will work on arrays from all compatible backends.

Parameters
  • x1 (array-like, shape (n1,d)) – matrix with n1 samples of size d

  • x2 (array-like, shape (n2,d), optional) – matrix with n2 samples of size d (if None then \(\mathbf{x_2} = \mathbf{x_1}\))

  • metric (str | callable, optional) – ‘sqeuclidean’ or ‘euclidean’ on all backends. With the numpy backend, the function also accepts any metric supported by scipy.spatial.distance.cdist: ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.

  • p (float, optional) – p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2.

  • w (array-like, rank 1) – Weights for the weighted metrics.

Returns

M – distance matrix computed with given metric

Return type

array-like, shape (n1, n2)
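For the default ‘sqeuclidean’ metric, the returned matrix has entries \(M_{ij} = \|\mathbf{x_1}_i - \mathbf{x_2}_j\|_2^2\). The sketch below computes it with plain numpy broadcasting; it covers only this one metric, the name sqeuclidean_dist_sketch is hypothetical, and it is not the library's implementation.

```python
import numpy as np


def sqeuclidean_dist_sketch(x1, x2=None):
    """Pairwise squared Euclidean distances between rows of x1 and x2."""
    x1 = np.asarray(x1, dtype=float)
    x2 = x1 if x2 is None else np.asarray(x2, dtype=float)
    diff = x1[:, None, :] - x2[None, :, :]  # shape (n1, n2, d)
    return np.sum(diff ** 2, axis=-1)       # shape (n1, n2)


x1 = np.array([[0.0, 0.0], [3.0, 4.0]])
x2 = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]])
M = sqeuclidean_dist_sketch(x1, x2)
M_self = sqeuclidean_dist_sketch(x1)  # x2 defaults to x1
```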

ot.utils.dist0(n, method='lin_square')[source]

Compute standard cost matrices of size (n, n) for OT problems

Parameters
  • n (int) – Size of the cost matrix.

  • method (str, optional) –

    Type of loss matrix chosen from:

    • ’lin_square’ : linear sampling between 0 and n-1, quadratic loss

Returns

M – Cost matrix computed with the given method.

Return type

ndarray, shape (n, n)
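Read literally, ‘lin_square’ places n points at 0, 1, ..., n-1 and uses the squared difference between positions as the cost. A short sketch under that reading (the name dist0_sketch is hypothetical):

```python
import numpy as np


def dist0_sketch(n):
    """Quadratic cost between n points sampled linearly on 0, 1, ..., n-1."""
    x = np.arange(n, dtype=float).reshape(n, 1)
    return (x - x.T) ** 2  # entry (i, j) is (i - j)^2


M0 = dist0_sketch(4)
```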

ot.utils.dots(*args)[source]

Chained matrix multiplication of all the arguments
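One way to picture this is a left-to-right reduction with np.dot. The sketch below is an assumption about the behaviour, not the library's code, and dots_sketch is a hypothetical name.

```python
import functools

import numpy as np


def dots_sketch(*args):
    """Multiply all arguments together, from left to right."""
    return functools.reduce(np.dot, args)


A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.eye(2)
v = np.array([1.0, 1.0])
out = dots_sketch(A, B, v)  # same as A @ B @ v
```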

ot.utils.euclidean_distances(X, Y, squared=False)[source]

Considering the rows of \(\mathbf{X}\) (and \(\mathbf{Y} = \mathbf{X}\)) as vectors, compute the distance matrix between each pair of vectors.

Note

This function is backend-compatible and will work on arrays from all compatible backends.

Parameters
  • X (array-like, shape (n_samples_1, n_features)) –

  • Y (array-like, shape (n_samples_2, n_features)) –

  • squared (boolean, optional) – Return squared Euclidean distances.

Returns

distances

Return type

array-like, shape (n_samples_1, n_samples_2)
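A common way to compute such a matrix without forming all pairwise differences is the expansion \(\|x - y\|^2 = \|x\|^2 - 2 x^T y + \|y\|^2\). The sketch below uses that identity; whether POT uses the same expansion is an assumption, and euclidean_distances_sketch is a hypothetical name.

```python
import numpy as np


def euclidean_distances_sketch(X, Y, squared=False):
    """Pairwise Euclidean distances via the squared-norm expansion."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    XX = np.sum(X * X, axis=1)[:, None]  # (n_samples_1, 1)
    YY = np.sum(Y * Y, axis=1)[None, :]  # (1, n_samples_2)
    D = XX + YY - 2.0 * X @ Y.T
    D = np.maximum(D, 0.0)               # clip tiny negatives from round-off
    return D if squared else np.sqrt(D)


X = np.array([[0.0, 0.0], [3.0, 4.0]])
Y = np.array([[0.0, 0.0], [6.0, 8.0]])
D = euclidean_distances_sketch(X, Y)
D2 = euclidean_distances_sketch(X, Y, squared=True)
```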

ot.utils.kernel(x1, x2, method='gaussian', sigma=1, **kwargs)[source]

Compute kernel matrix

ot.utils.label_normalization(y, start=0)[source]

Transform labels to start at a given value

Parameters
  • y (array-like, shape (n, )) – The vector of labels to be normalized.

  • start (int) – Desired value for the smallest label in \(\mathbf{y}\) (default=0)

Returns

y – The input vector of labels normalized according to given start value.

Return type

array-like, shape (n, )
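The transformation amounts to shifting every label by the same offset so that the smallest one equals start. A minimal sketch, with label_normalization_sketch as a hypothetical name:

```python
import numpy as np


def label_normalization_sketch(y, start=0):
    """Shift labels so that the smallest label equals `start`."""
    y = np.asarray(y)
    return y - (y.min() - start)


y = np.array([3, 5, 3, 7])
y0 = label_normalization_sketch(y)           # smallest label becomes 0
y1 = label_normalization_sketch(y, start=1)  # smallest label becomes 1
```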

ot.utils.laplacian(x)[source]

Compute Laplacian matrix

ot.utils.list_to_array(*lst)[source]

Convert each argument that is a Python list into a numpy array

ot.utils.parmap(f, X, nprocs='default')[source]

Parallel map for multiprocessing. This function is deprecated and now only performs a regular map.

ot.utils.proj_simplex(v, z=1)[source]

Compute the closest point (orthogonal projection) on the generalized (n-1)-simplex of a vector \(\mathbf{v}\) with respect to the Euclidean distance, thus solving:

\[ \begin{aligned} \mathcal{P}(\mathbf{v}) \in \mathop{\arg \min}_\gamma \quad & \| \gamma - \mathbf{v} \|_2 \\ \text{s.t.} \quad & \gamma^T \mathbf{1} = z \\ & \gamma \geq 0 \end{aligned} \]

If \(\mathbf{v}\) is a 2d array, each column is projected independently (the projection is computed along axis 0)

Note

This function is backend-compatible and will work on arrays from all compatible backends.

Parameters
  • v ({array-like}, shape (n, d)) –

  • z (int, optional) – ‘size’ of the simplex (each vector sums to z, 1 by default)

Returns

h – Array of projections on the simplex

Return type

ndarray, shape (n, d)
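The projection problem above has a well-known sort-based solution: sort the entries in decreasing order, find the largest prefix whose shifted entries stay positive, and subtract the resulting threshold. The sketch below applies it column by column with numpy. It illustrates the technique and is not claimed to match POT's code; proj_simplex_sketch is a hypothetical name.

```python
import numpy as np


def proj_simplex_sketch(v, z=1.0):
    """Euclidean projection of each column of v onto {g >= 0, sum(g) = z}."""
    v = np.asarray(v, dtype=float)
    squeeze = v.ndim == 1
    if squeeze:
        v = v[:, None]
    n, d = v.shape
    u = np.sort(v, axis=0)[::-1]           # each column in decreasing order
    cssv = np.cumsum(u, axis=0) - z        # shifted cumulative sums
    ind = np.arange(1, n + 1)[:, None]
    cond = u - cssv / ind > 0              # True on a prefix of each column
    rho = cond.sum(axis=0)                 # size of the active set (>= 1 for z > 0)
    theta = cssv[rho - 1, np.arange(d)] / rho
    w = np.maximum(v - theta[None, :], 0.0)
    return w[:, 0] if squeeze else w


p1 = proj_simplex_sketch(np.array([0.5, 0.5, 2.0]))
P = proj_simplex_sketch(np.array([[0.5, 0.2], [0.5, 0.3], [2.0, 0.5]]))
```

In the second call the right-hand column already lies on the simplex, so the projection leaves it unchanged, while the left-hand column collapses onto a vertex.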

ot.utils.tic()[source]

Python implementation of Matlab tic() function

ot.utils.toc(message='Elapsed time : {} s')[source]

Python implementation of Matlab toc() function

ot.utils.toq()[source]

Python implementation of Julia toq() function
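The three timing helpers share one start time: tic() records it, toc() prints the elapsed time, and toq() reports it silently. The sketch below mimics that pattern with the standard library only. The shared state, the return values, and the *_sketch names are assumptions made for illustration.

```python
import time

_timer = {"start": None}  # hypothetical shared state for the three helpers


def tic_sketch():
    """Record the current time as the reference point."""
    _timer["start"] = time.perf_counter()


def toc_sketch(message="Elapsed time : {} s"):
    """Print and return the time elapsed since the last tic."""
    elapsed = time.perf_counter() - _timer["start"]
    print(message.format(elapsed))
    return elapsed


def toq_sketch():
    """Return the time elapsed since the last tic without printing."""
    return time.perf_counter() - _timer["start"]


tic_sketch()
total = sum(range(100_000))  # some work to time
elapsed = toq_sketch()
```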

ot.utils.unif(n, type_as=None)[source]

Return a uniform histogram of length n (simplex).

Parameters
  • n (int) – number of bins in the histogram

  • type_as (array_like) – array of the same type as the expected output (numpy/pytorch/jax)

Returns

h – histogram of length n such that \(\forall i, \mathbf{h}_i = \frac{1}{n}\)

Return type

array_like (n,)
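For the numpy case, the formula \(\mathbf{h}_i = \frac{1}{n}\) reduces to a single line. The sketch ignores the type_as argument and uses the hypothetical name unif_sketch.

```python
import numpy as np


def unif_sketch(n):
    """Uniform histogram of length n: every bin has mass 1/n."""
    return np.ones(n) / n


h = unif_sketch(4)
```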

Classes

class ot.utils.BaseEstimator[source]

Base class for most objects in POT

Code adapted from sklearn BaseEstimator class

Notes

All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).

get_params(deep=True)[source]

Get parameters for this estimator.

Parameters

deep (bool, optional) – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns

params – Parameter names mapped to their values.

Return type

mapping of string to any

set_params(**params)[source]

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as pipelines). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Return type

self
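The <component>__<parameter> convention can be illustrated with a small stand-in class. This sketch reads parameter names from the __init__ signature, in the spirit of scikit-learn, but it is not POT's BaseEstimator; EstimatorSketch, Inner and Outer are hypothetical names.

```python
import inspect


class EstimatorSketch:
    """Hypothetical stand-in showing the get_params / set_params contract."""

    def get_params(self, deep=True):
        params = {}
        for name in inspect.signature(self.__init__).parameters:
            value = getattr(self, name)
            params[name] = value
            # With deep=True, expose nested estimator parameters as name__sub
            if deep and hasattr(value, "get_params"):
                for sub_name, sub_value in value.get_params().items():
                    params[f"{name}__{sub_name}"] = sub_value
        return params

    def set_params(self, **params):
        for key, value in params.items():
            name, _, sub_key = key.partition("__")
            if sub_key:
                # Forward "component__parameter" to the nested estimator
                getattr(self, name).set_params(**{sub_key: value})
            else:
                setattr(self, name, value)
        return self


class Inner(EstimatorSketch):
    def __init__(self, reg=1.0):
        self.reg = reg


class Outer(EstimatorSketch):
    def __init__(self, inner=None, max_iter=10):
        self.inner = inner
        self.max_iter = max_iter


model = Outer(inner=Inner(reg=0.5))
model.set_params(max_iter=50, inner__reg=2.0)
params = model.get_params()
```

Because set_params returns self, calls can be chained, and because parameter names come from the __init__ signature, the explicit-keyword-argument rule in the Notes above is what makes this introspection possible.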

class ot.utils.deprecated(extra='')[source]

Decorator to mark a function or class as deprecated.

Adapted from the deprecated class of the scikit-learn package (https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py). It issues a warning when the function is called or the class is instantiated, and adds a warning to the docstring. The optional extra argument is appended to the deprecation message and to the docstring.

Note

To use this with the default value for extra, use empty parentheses:

>>> from ot.utils import deprecated
>>> @deprecated()
... def some_function(): pass

Parameters

extra (str) – To be added to the deprecation messages.
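The mechanism can be sketched with the standard warnings module. The version below handles plain functions only, and its message wording is invented for illustration; deprecated_sketch is a hypothetical name and not the class documented here.

```python
import functools
import warnings


class deprecated_sketch:
    """Hypothetical decorator: warn on call and flag the docstring."""

    def __init__(self, extra=""):
        self.extra = extra

    def __call__(self, func):
        msg = f"Function {func.__name__} is deprecated"
        if self.extra:
            msg += f"; {self.extra}"

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Warn at the caller's location, then run the original function
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        wrapped.__doc__ = f"DEPRECATED: {msg}\n\n{func.__doc__ or ''}"
        return wrapped


@deprecated_sketch("use new_function instead")
def old_function(x):
    """Double x."""
    return 2 * x


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_function(3)
```

Writing the decorator as a class with __init__ and __call__ is what makes the empty parentheses necessary: the first call builds the decorator, the second wraps the function.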

Exceptions

UndefinedParameter

Raised when an undefined parameter is used