pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface for creating new metrics, a toolkit to facilitate metric computation in distributed training, and tools for PyTorch model evaluation.

Home Page: https://pytorch.org/torcheval

Add Wasserstein Distance

bobakfb opened this issue · comments

bobakfb commented

🚀 The feature

Wasserstein Distance

We'd like to start building out statistical metrics. In this issue we cover the Wasserstein distance, also called the Earth Mover's Distance, which is a measure of the similarity between two distributions.

The Wasserstein distance between two distributions is, intuitively, the minimum amount of soil (times the distance it is moved) that would need to be moved to reshape one pile into the other, if the two distributions were represented by two piles of soil. It is not tractable in high dimensions, so we will restrict ourselves to one dimension for this issue.
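In one dimension the distance has a simple closed form in terms of the cumulative distribution functions U and V of the two samples (this is also the quantity scipy computes):

W_1(u, v) = \int_{\mathbb{R}} \lvert U(t) - V(t) \rvert \, dt

i.e. the area between the two empirical CDFs, which is what the pure-PyTorch implementation should reproduce.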

How To

Make sure to take a look at scipy's implementation. Also take a look at the quickstart guide, which explains how to implement metrics and has a basic implementation of the KS-test statistic, which is quite similar to the earth mover's distance. (Note: we must implement this function in pure PyTorch; the example using scipy in the quickstart is just for simplicity.)

Requirements

Implement the wasserstein_1d function:

def wasserstein_1d(x: torch.Tensor, y: torch.Tensor, x_weights: Optional[torch.Tensor] = None, y_weights: Optional[torch.Tensor] = None) -> torch.Tensor:

And the class

class Wasserstein1D(Metric[torch.Tensor]):
  def __init__(self, device: Optional[torch.device] = None)

Class-based implementations keep internal states, which are accumulated as training occurs with calls to update(). The internal states can then be used to calculate the full metric with compute(), which returns the result. In addition, the class interface must have a merge_state() function, which defines how to aggregate the internal state variables when they are being updated independently in different processes. For WD, the implementation will probably be similar to AUC, where the internal states of the class-based implementation are just lists of all the samples seen so far. The only arg the constructor needs is device. A rough sketch follows.
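This is only a sketch of what that could look like, assuming the AUC-style approach of accumulating raw samples: the state names, the _add_state helper, the inference-mode decorators, and the import paths are assumptions modeled on the existing torcheval metrics, and compute() defers to the wasserstein_1d functional requested above.

import torch
from typing import Iterable, Optional

from torcheval.metrics.functional import wasserstein_1d  # the functional metric requested in this issue
from torcheval.metrics.metric import Metric


class Wasserstein1D(Metric[torch.Tensor]):
    def __init__(self, device: Optional[torch.device] = None) -> None:
        super().__init__(device=device)
        # Accumulate raw samples (and optional weights) across update() calls,
        # the same way AUC keeps lists of everything it has seen.
        self._add_state("dist_1_samples", [])
        self._add_state("dist_2_samples", [])
        self._add_state("dist_1_weights", [])
        self._add_state("dist_2_weights", [])

    @torch.inference_mode()
    def update(self, x, y, x_weights=None, y_weights=None):
        # Sketch ignores the corner case of mixing weighted and unweighted updates.
        self.dist_1_samples.append(x.to(self.device))
        self.dist_2_samples.append(y.to(self.device))
        if x_weights is not None:
            self.dist_1_weights.append(x_weights.to(self.device))
        if y_weights is not None:
            self.dist_2_weights.append(y_weights.to(self.device))
        return self

    @torch.inference_mode()
    def compute(self) -> torch.Tensor:
        # The distance is re-derived from all accumulated samples.
        return wasserstein_1d(
            torch.cat(self.dist_1_samples),
            torch.cat(self.dist_2_samples),
            torch.cat(self.dist_1_weights) if self.dist_1_weights else None,
            torch.cat(self.dist_2_weights) if self.dist_2_weights else None,
        )

    @torch.inference_mode()
    def merge_state(self, metrics: Iterable["Wasserstein1D"]) -> "Wasserstein1D":
        # Aggregating across processes only requires concatenating the sample lists.
        for other in metrics:
            self.dist_1_samples.extend(t.to(self.device) for t in other.dist_1_samples)
            self.dist_2_samples.extend(t.to(self.device) for t in other.dist_2_samples)
            self.dist_1_weights.extend(t.to(self.device) for t in other.dist_1_weights)
            self.dist_2_weights.extend(t.to(self.device) for t in other.dist_2_weights)
        return self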

The functional implementation just takes one set of samples from the first distribution (in x) and one set of samples from the second distribution (in y) and returns the earth mover's distance. To keep our implementations clean, we have a well-defined set of input checks, which can be seen below; a pure-PyTorch sketch following this layout is given after the list.

<metric>(input, target, *params)  # Returns the computed metric for the given predictions (input) and target values

# supporting functions

_<metric>_param_check(...)  # Checks that the parameters (like number of classes) are valid
_<metric>_update(input, target)  # Returns the intermediate variables used to calculate the metric; these should be the same as the state variables for the class-based version
_<metric>_update_input_check(...)  # Checks that the input and target are congruent with the metric definition and parameters (e.g. they are the right shape for the given number of classes)
_<metric>_compute(*state_vars)  # Computes the metric given the state variables
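Here is a rough pure-PyTorch sketch of that layout applied to wasserstein_1d. The math mirrors scipy's CDF-based computation (evaluate both empirical CDFs on the pooled sample points and integrate the absolute gap); the helper names just follow the convention above, plus a small _cdf_at helper for readability, and none of this exists in torcheval yet.

import torch
from typing import Optional


def _wasserstein_update_input_check(
    x: torch.Tensor,
    y: torch.Tensor,
    x_weights: Optional[torch.Tensor] = None,
    y_weights: Optional[torch.Tensor] = None,
) -> None:
    # Samples must be non-empty 1D tensors; weights, if given, must match their samples.
    if x.ndim != 1 or x.numel() == 0 or y.ndim != 1 or y.numel() == 0:
        raise ValueError("x and y must be non-empty 1D tensors of samples")
    if x_weights is not None and x_weights.shape != x.shape:
        raise ValueError("x_weights must have the same shape as x")
    if y_weights is not None and y_weights.shape != y.shape:
        raise ValueError("y_weights must have the same shape as y")


def _cdf_at(values: torch.Tensor, weights: Optional[torch.Tensor], points: torch.Tensor) -> torch.Tensor:
    # Empirical (optionally weighted) CDF of `values`, evaluated at `points`.
    sorted_values, sorter = torch.sort(values)
    idx = torch.searchsorted(sorted_values, points, right=True)
    if weights is None:
        return idx.to(torch.float64) / values.numel()
    cum_weights = torch.cat(
        [
            torch.zeros(1, dtype=torch.float64),
            torch.cumsum(weights[sorter].to(torch.float64), dim=0),
        ]
    )
    return cum_weights[idx] / cum_weights[-1]


def _wasserstein_compute(
    x: torch.Tensor,
    y: torch.Tensor,
    x_weights: Optional[torch.Tensor],
    y_weights: Optional[torch.Tensor],
) -> torch.Tensor:
    x = x.to(torch.float64)
    y = y.to(torch.float64)
    # Evaluate both CDFs on the pooled, sorted sample points and integrate |F_x - F_y|.
    all_values, _ = torch.sort(torch.cat([x, y]))
    deltas = torch.diff(all_values)
    x_cdf = _cdf_at(x, x_weights, all_values[:-1])
    y_cdf = _cdf_at(y, y_weights, all_values[:-1])
    return torch.sum(torch.abs(x_cdf - y_cdf) * deltas)


def wasserstein_1d(
    x: torch.Tensor,
    y: torch.Tensor,
    x_weights: Optional[torch.Tensor] = None,
    y_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    _wasserstein_update_input_check(x, y, x_weights, y_weights)
    return _wasserstein_compute(x, y, x_weights, y_weights)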

Examples:

Functional:

>>> from torcheval.metrics.functional import wasserstein_1d
>>> wasserstein_1d(torch.tensor([0,1,2]), torch.tensor([0,1,1]))
0.33333333333333337
>>> wasserstein_1d(torch.tensor([0,1,2]), torch.tensor([0,1,1]), torch.tensor([1,2,0]), torch.tensor([1,1,1]))
0.0
>>> wasserstein_1d(torch.tensor([0,1,2,2]), torch.tensor([0,1]))
0.75

Class-based:

>>> from torcheval.metrics import Wasserstein1D
>>> metric = Wasserstein1D()
>>> metric.update(torch.tensor([0,1,2,2]), torch.tensor([0,1]))
>>> metric.compute()
0.75
>>> metric = Wasserstein1D()
>>> metric.update(torch.tensor([0,1,2]), torch.tensor([0,1,1]), torch.tensor([1,2,0]), torch.tensor([1,1,1]))
>>> metric.compute()
0
>>> metric = Wasserstein1D()
>>> metric.update(torch.tensor([0,1,2]), torch.tensor([0,1,1]))
>>> metric.compute()
0.33333333333333337
>>> metric.update(torch.tensor([1,1,1]), torch.tensor([1,1,1]))
>>> metric.compute()
0.16666666666666663

Steps required

  • Create the new wasserstein_1d function in a new file fbcode/torcheval/metrics/functional/statistical/wasserstein.py
  • Create the new Wasserstein1D class in a new file fbcode/torcheval/metrics/statistical/wasserstein.py
  • Add the function and class to the __init__ files for easy importing (functional: 1, 2) (class based: 3, 4) -- see the sketch after this list
  • Create new test cases to cover the new functional metric in torcheval/tests/metrics/functional/statistical/test_wasserstein.py
  • Create new test cases to cover the new class metric in torcheval/tests/metrics/statistical/test_wasserstein.py
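The __init__ additions are just the usual import-plus-__all__ pattern used for the existing metrics; something like the following (the exact file contents here are assumptions, check the neighboring entries in those files):

# torcheval/metrics/functional/statistical/__init__.py (new file)
from torcheval.metrics.functional.statistical.wasserstein import wasserstein_1d

__all__ = ["wasserstein_1d"]

# torcheval/metrics/functional/__init__.py (extend the existing imports and __all__ list)
from torcheval.metrics.functional.statistical import wasserstein_1d

The class-based Wasserstein1D is exported the same way from torcheval/metrics/statistical/__init__.py and torcheval/metrics/__init__.py.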

Testing your changes

Take a look at the contributors guide to see how to run unit tests.

A good suite of tests should do the following:

  • A random data test like this one in binned auprc -- be sure to add the random data getter to the random data module.
  • Use scipy's implementation with random inputs as a one-to-one comparison (see the test sketch after this list).
  • All input types and shapes should be exercised.
  • Every arg should be exercised; arg combinations that interact in any way should also get their own specific test.
  • If your input needs special characteristics to test some cases/arg combinations (e.g. the input must be sorted), be sure to hard-code inputs and outputs.
  • The idea is to make sure every feature of the code you wrote is tested.
  • All error endpoints should be triggered and checked with assertRaisesRegex.
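A hypothetical shape for the random-data comparison and one error check is sketched below; scipy.stats.wasserstein_distance is the reference implementation, while the wasserstein_1d import and the error message are assumptions about the code this issue asks for.

import unittest

import torch
from scipy.stats import wasserstein_distance

from torcheval.metrics.functional import wasserstein_1d


class TestWasserstein1DFunctional(unittest.TestCase):
    def _test_against_scipy(self, n_x: int, n_y: int, weighted: bool) -> None:
        # Random samples (and optionally random weights) compared 1:1 against scipy.
        x = torch.rand(n_x)
        y = torch.rand(n_y)
        x_weights = torch.rand(n_x) if weighted else None
        y_weights = torch.rand(n_y) if weighted else None
        expected = wasserstein_distance(
            x.numpy(),
            y.numpy(),
            x_weights.numpy() if x_weights is not None else None,
            y_weights.numpy() if y_weights is not None else None,
        )
        result = wasserstein_1d(x, y, x_weights, y_weights)
        torch.testing.assert_close(result, torch.tensor(expected, dtype=result.dtype))

    def test_wasserstein_random_data(self) -> None:
        for weighted in (False, True):
            self._test_against_scipy(n_x=32, n_y=48, weighted=weighted)

    def test_wasserstein_invalid_input(self) -> None:
        # Error endpoints are asserted with assertRaisesRegex.
        with self.assertRaisesRegex(ValueError, "non-empty 1D"):
            wasserstein_1d(torch.rand(3, 2), torch.rand(5))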

Use MetricClassTester to test the class-based code. Make sure to exercise multiple updates across different processes by setting the number of elements in the update lists to a multiple of num_processes (4 per process is normally a good target).
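Roughly, the class test could look like the sketch below. The run_class_implementation_tests keyword names, the MetricClassTester import path, and the state names are assumptions that should be checked against the existing class tests (e.g. the AUC ones).

import torch

from torcheval.metrics import Wasserstein1D
from torcheval.metrics.functional import wasserstein_1d
from torcheval.utils.test_utils.metric_class_tester import MetricClassTester


class TestWasserstein1D(MetricClassTester):
    def test_wasserstein1d_class(self) -> None:
        num_total_updates = 8  # a multiple of the number of processes the tester spawns
        x_batches = [torch.rand(16) for _ in range(num_total_updates)]
        y_batches = [torch.rand(16) for _ in range(num_total_updates)]
        # compute() over all updates should equal the functional result on the pooled samples.
        expected = wasserstein_1d(torch.cat(x_batches), torch.cat(y_batches))
        self.run_class_implementation_tests(
            metric=Wasserstein1D(),
            state_names={"dist_1_samples", "dist_2_samples", "dist_1_weights", "dist_2_weights"},
            update_kwargs={"x": x_batches, "y": y_batches},
            compute_result=expected,
            num_total_updates=num_total_updates,
        )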

Please feel free to ask any questions!

commented

Hello, reaching out to get assigned to this. Thanks!

commented

Thanks @GaganCodes! Please let me know if you have any Qs!