caojiezhang / PyTorchOT

implements optimal transport algorithms in pytorch

PyTorchOT

Implements Sinkhorn optimal transport algorithms in PyTorch. Currently two versions of the Sinkhorn algorithm are implemented: the original and the log-stabilized version.
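Both versions solve the same entropy-regularized transport problem. The sketch below shows the difference between them. It assumes uniform weights on the points, since the usage example passes only a cost matrix. Each function returns the transport cost <P, M> of the regularized plan P. `sinkhorn` and `sinkhorn_log` are illustrative names, not the repository's `sink`. The second function is a fully log-domain variant, and it may differ in detail from the stabilization scheme the repository calls log-stabilized.

```python
import math

import torch


def sinkhorn(M, reg, num_iters=1000, tol=1e-9):
    """Standard Sinkhorn with uniform marginals; returns <P, M> for the regularized plan P."""
    n, m = M.shape
    a = torch.full((n,), 1.0 / n, dtype=M.dtype, device=M.device)
    b = torch.full((m,), 1.0 / m, dtype=M.dtype, device=M.device)
    K = torch.exp(-M / reg)              # Gibbs kernel; entries underflow to 0 when M / reg is large
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(num_iters):
        u_prev = u
        u = a / (K @ v)                  # rescale rows to match marginal a
        v = b / (K.t() @ u)              # rescale columns to match marginal b
        if torch.max(torch.abs(u - u_prev)) < tol:
            break
    P = u[:, None] * K * v[None, :]      # plan diag(u) K diag(v)
    return torch.sum(P * M)


def sinkhorn_log(M, reg, num_iters=1000, tol=1e-9):
    """Log-domain Sinkhorn: same fixed point, computed with dual potentials f, g
    and logsumexp so that exp(-M / reg) is never formed explicitly."""
    n, m = M.shape
    log_a = torch.full((n,), -math.log(n), dtype=M.dtype, device=M.device)
    log_b = torch.full((m,), -math.log(m), dtype=M.dtype, device=M.device)
    f = torch.zeros(n, dtype=M.dtype, device=M.device)
    g = torch.zeros(m, dtype=M.dtype, device=M.device)
    for _ in range(num_iters):
        f_prev = f
        # f_i = reg * (log a_i - logsumexp_j((g_j - M_ij) / reg)), i.e. u = exp(f / reg)
        f = reg * (log_a - torch.logsumexp((g[None, :] - M) / reg, dim=1))
        g = reg * (log_b - torch.logsumexp((f[:, None] - M) / reg, dim=0))
        if torch.max(torch.abs(f - f_prev)) < tol:
            break
    P = torch.exp((f[:, None] + g[None, :] - M) / reg)  # same plan as diag(u) K diag(v)
    return torch.sum(P * M)
```

For moderate `reg` the two functions agree. When every cost in a row is large relative to `reg` (for example, costs around 1 with reg=1e-3 give exp(-1000), which is 0 even in float64), `K @ v` becomes zero and the standard updates produce inf/NaN. The log-domain version stays finite in that case.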

Example usage:

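import torch
from ot_pytorch import sink

# M is the (n, m) cost matrix between two point clouds; Euclidean distances are one choice
x = torch.rand(100, 2)
y = torch.rand(100, 2)
M = torch.cdist(x, y)
dist = sink(M, reg=5, cuda=False)  # reg: strength of the entropic regularization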

Setting cuda=True runs the computation on a CUDA GPU.
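The usage example leaves the construction of M open. The sketch below shows one way to build it. `pairwise_distances` is a hypothetical helper, not part of ot_pytorch. It computes Euclidean distances with the expansion ||x||² + ||y||² - 2x·y. The tensors are placed on the GPU only when one is available, which is the situation where cuda=True applies.

```python
import torch


def pairwise_distances(x, y, squared=False):
    """Hypothetical helper: x is (n, d), y is (m, d); returns the (n, m) Euclidean distance matrix."""
    x_norm = (x ** 2).sum(dim=1, keepdim=True)       # (n, 1)
    y_norm = (y ** 2).sum(dim=1, keepdim=True).t()   # (1, m)
    d2 = x_norm + y_norm - 2.0 * x @ y.t()           # squared distances via the expansion
    d2 = torch.clamp(d2, min=0.0)                     # remove small negative round-off
    return d2 if squared else torch.sqrt(d2)


# Create the data on the GPU if one is present, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(200, 2, device=device)
y = torch.rand(300, 2, device=device)
M = pairwise_distances(x, y)
print(M.shape, M.device)
```

`torch.cdist(x, y)` returns the same matrix. The explicit version only makes the cost visible.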

The examples.py file contains two basic examples.

Example 1:

Let Z_i ~ Uniform[0, 1] and define the data X_i = (0, Z_i) and Y_i = (θ, Z_i) for i = 1, ..., N, where the parameter θ is varied over [-1, 1]. The true optimal transport distance is |θ|. The algorithm yields:

[Figure: distance computed by the algorithm as θ varies over [-1, 1]]
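The experiment in Example 1 can be approximated with the sketch below. Several settings are assumptions and do not come from examples.py: Euclidean cost, N = 50 samples, 21 values of θ, and reg = 0.01. `sinkhorn_cost_log` is a hypothetical log-domain stand-in for `sink`. The regularized plan is more spread out than the exact one, so the estimates should track |θ| up to a small bias.

```python
import math

import torch


def sinkhorn_cost_log(M, reg, num_iters=1000):
    """Hypothetical stand-in for sink(): log-domain Sinkhorn with uniform marginals."""
    n, m = M.shape
    log_a = torch.full((n,), -math.log(n), dtype=M.dtype)
    log_b = torch.full((m,), -math.log(m), dtype=M.dtype)
    f = torch.zeros(n, dtype=M.dtype)
    g = torch.zeros(m, dtype=M.dtype)
    for _ in range(num_iters):
        f = reg * (log_a - torch.logsumexp((g[None, :] - M) / reg, dim=1))
        g = reg * (log_b - torch.logsumexp((f[:, None] - M) / reg, dim=0))
    P = torch.exp((f[:, None] + g[None, :] - M) / reg)  # regularized transport plan
    return torch.sum(P * M).item()


torch.manual_seed(0)
N = 50
Z = torch.rand(N, dtype=torch.float64)                         # Z_i ~ Uniform[0, 1]
thetas = torch.linspace(-1.0, 1.0, 21, dtype=torch.float64)
estimates = []
for theta in thetas:
    X = torch.stack([torch.zeros(N, dtype=torch.float64), Z], dim=1)               # X_i = (0, Z_i)
    Y = torch.stack([torch.full((N,), theta.item(), dtype=torch.float64), Z], dim=1)  # Y_i = (θ, Z_i)
    M = torch.cdist(X, Y)                                        # Euclidean cost (assumed)
    estimates.append(sinkhorn_cost_log(M, reg=0.01))

for theta, est in zip(thetas.tolist(), estimates):
    print(f"theta={theta:+.1f}  sinkhorn={est:.4f}  true={abs(theta):.4f}")
```

Reducing `reg` shrinks the bias, but it also slows convergence. The log-domain form keeps small values of `reg` numerically safe.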
