MMDAgg
This package implements the MMDAgg test for two-sample testing, as proposed in our paper MMD Aggregated Two-Sample Test. The experiments of the paper can be reproduced using the mmdagg-paper repository. The package contains implementations both in Numpy and in Jax, we recommend using the Jax version as it runs 100 times faster after compilation (results from the notebook demo_speed.ipynb in the mmdagg-paper repository). The notebook also contains a demo showing how to use our MMDAgg test. We also provide installation instructions and example code below.
Speed in s | Numpy (CPU) | Jax (CPU) | Jax (GPU) |
---|---|---|---|
MMDAgg | 43.1 | 14.9 | 0.495 |
Requirements
The requirements for the Numpy version are:
python 3.9
numpy
scipy
The requirements for the Jax version are:
python 3.9
jax
jaxlib
Installation
First, we recommend creating a conda environment:
conda create --name mmdagg-env python=3.9
conda activate mmdagg-env
# can be deactivated by running:
# conda deactivate
We then install the required depedencies (Jax installation instructions) by running either:
- for GPU:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # conda install -c conda-forge -c nvidia pip numpy scipy cuda-nvcc "jaxlib=0.4.1=*cuda*" jax
- or, for CPU:
conda install -c conda-forge -c nvidia pip numpy scipy cuda-nvcc jaxlib=0.4.1 jax
Our mmdagg
package can then be installed as follows:
pip install git+https://github.com/antoninschrab/mmdagg.git
MMDAgg
Two-sample testing: Given arrays X of shape mmdagg(X, Y)
returns 0 if the samples X and Y are believed to come from the same distribution, and 1 otherwise.
Jax compilation: The first time the function is evaluated, Jax compiles it. After compilation, it can fastly be evaluated at any other X and Y of the same shape. If the function is given arrays with new shapes, the function is compiled again. For details, check out the demo_speed.ipynb notebook on the mmdagg-paper repository.
# import modules
>>> import jax.numpy as jnp
>>> from jax import random
>>> from mmdagg.jax import mmdagg, human_readable_dict # jax version
>>> # from mmdagg.np import mmdagg
# generate data for two-sample test
>>> key = random.PRNGKey(0)
>>> key, subkey = random.split(key)
>>> subkeys = random.split(subkey, num=2)
>>> X = random.uniform(subkeys[0], shape=(500, 10))
>>> Y = random.uniform(subkeys[1], shape=(500, 10)) + 1
# run MMDAgg test
>>> output = mmdagg(X, Y)
>>> output
Array(1, dtype=int32)
>>> output.item()
1
>>> output, dictionary = mmdagg(X, Y, return_dictionary=True)
>>> output
Array(1, dtype=int32)
>>> human_readable_dict(dictionary)
>>> dictionary
{'MMDAgg test reject': True,
'Single test 1': {'Bandwidth': 1.0,
'MMD': 5.788900671177544e-05,
'MMD quantile': 0.0009193826699629426,
'Kernel IMQ': True,
'Reject': False,
'p-value': 0.41079461574554443,
'p-value threshold': 0.01699146442115307},
...
}
MMDAggInc
For a computationally efficient version of MMDAgg which can run in linear time, check out our package agginc
in the agginc repository.
This package implements the MMDAggInc test (together with HISCAggInc and KSDAggInc) proposed in our paper Efficient Aggregated Kernel Tests using Incomplete U-statistics with reproducible experiments in the agginc-paper repository.
Contact
If you have any issues running our MMDAgg test, please do not hesitate to contact Antonin Schrab.
Affiliations
Centre for Artificial Intelligence, Department of Computer Science, University College London
Gatsby Computational Neuroscience Unit, University College London
Inria London
Bibtex
@article{schrab2021mmd,
author = {Antonin Schrab and Ilmun Kim and M{\'e}lisande Albert and B{\'e}atrice Laurent and Benjamin Guedj and Arthur Gretton},
title = {{MMD} Aggregated Two-Sample Test},
journal = {Journal of Machine Learning Research},
year = {2023},
volume = {24},
number = {194},
pages = {1--81},
url = {http://jmlr.org/papers/v24/21-1289.html}
}
License
MIT License (see LICENSE.md).