matt-graham / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/CPU.

Home Page: http://num.pyro.ai

NumPyro

What is NumPyro?

NumPyro is a small probabilistic programming library built on JAX. It essentially provides a NumPy backend for Pyro, with some minor changes to the inference API and syntax. Since we use JAX, we get autograd and JIT compilation to GPU / CPU for free. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

NumPyro is designed to be lightweight and focuses on providing a flexible substrate that users can build on:

  • Pyro Primitives: NumPyro programs can contain regular Python and NumPy code, in addition to Pyro primitives like sample and param. Model code should look very similar to Pyro code, apart from minor differences between the PyTorch and NumPy APIs. See Examples.
  • Inference algorithms: NumPyro currently supports Hamiltonian Monte Carlo, including an implementation of the No-U-Turn Sampler (NUTS). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the Verlet integration step, which includes multiple gradient computations. With JAX, we can compose jit and grad to compile the entire integration step into an XLA-optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree-building stage in NUTS (this is possible using Iterative NUTS). There is also a basic Variational Inference implementation for reparameterized distributions.
  • Distributions: The numpyro.distributions module provides distribution classes, constraints and bijective transforms. The distribution classes wrap samplers implemented to work with JAX's functional pseudo-random number generator. The design of the distributions module largely follows PyTorch's: a major subset of the API is implemented, including most of the common distributions in torch.distributions, so Pyro and PyTorch users can rely on the same API and batching semantics. Constraints and transforms are particularly useful when working with distributions that have bounded support.
  • Effect handlers: As in Pyro, primitives like sample and param can be interpreted with side effects using effect handlers from the numpyro.handlers module, and these can easily be extended to implement custom inference algorithms and inference utilities.
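As a minimal sketch of the jit/grad composition described under Inference algorithms, the snippet below takes gradients of a potential energy with jax.grad and compiles a whole leapfrog (Verlet) trajectory with jax.jit and jax.lax.fori_loop. The standard-normal potential and the function names are illustrative assumptions; this is not NumPyro's internal integrator.

```python
import jax
import jax.numpy as jnp


def potential_energy(q):
    # Negative log density of a standard normal, up to an additive constant.
    return 0.5 * jnp.sum(q ** 2)


def kinetic_energy(p):
    # Identity mass matrix.
    return 0.5 * jnp.sum(p ** 2)


# grad and jit compose: the gradient call is traced into the compiled kernel.
grad_potential = jax.grad(potential_energy)


@jax.jit
def leapfrog_trajectory(q, p, step_size, num_steps):
    def step(_, state):
        q, p = state
        p = p - 0.5 * step_size * grad_potential(q)  # half step for momentum
        q = q + step_size * p                        # full step for position
        p = p - 0.5 * step_size * grad_potential(q)  # half step for momentum
        return q, p

    # fori_loop keeps the entire trajectory inside one XLA computation,
    # so there is no Python overhead per integration step.
    return jax.lax.fori_loop(0, num_steps, step, (q, p))


q1, p1 = leapfrog_trajectory(jnp.array([1.0, -0.5]), jnp.array([0.3, 0.8]), 0.1, 20)
```

Leapfrog is used in HMC because it is time-reversible and keeps the total energy (potential plus kinetic) nearly constant over long trajectories.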

A Simple Example - 8 Schools

Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.

The data is given by:

>>> import numpy as np
>>> J = 8
>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

where y are the treatment effects and sigma are their standard errors. We build a hierarchical model for the study in which the group-level parameters theta for each school are drawn from a Normal distribution with unknown mean mu and standard deviation tau, while the observed data are in turn generated from a Normal distribution whose mean and standard deviation are given by theta (the true effect) and sigma, respectively. This allows us to estimate the population-level parameters mu and tau by pooling information across all the observations, while still allowing for individual variation among the schools through the group-level theta parameters.
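The hierarchical model sits between two simpler estimates: no pooling, where each school's effect is estimated from its own y alone, and complete pooling, where all schools share a single effect. For comparison, the complete-pooling estimate is the precision-weighted mean of y with weights 1/sigma², and its standard error is 1/sqrt(Σ 1/sigma²). A plain-NumPy sketch for the data above (the function name is hypothetical):

```python
import numpy as np

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def complete_pooling(y, sigma):
    """Precision-weighted mean of y and its standard error."""
    weights = 1.0 / sigma ** 2  # precision of each school's estimate
    pooled_mean = np.sum(weights * y) / np.sum(weights)
    pooled_se = 1.0 / np.sqrt(np.sum(weights))
    return pooled_mean, pooled_se


pooled_mean, pooled_se = complete_pooling(y, sigma)
print(f"complete pooling: {pooled_mean:.2f} +/- {pooled_se:.2f}")
# prints: complete pooling: 7.69 +/- 4.07
```

In the hierarchical model, tau controls how strongly each theta is shrunk toward a common value: as tau approaches 0 the model approaches complete pooling, and as tau grows it approaches no pooling.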

>>> import numpyro
>>> import numpyro.distributions as dist
>>> # Eight Schools example
... def eight_schools(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
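To make the generative process concrete, here is a plain-NumPy sketch that simulates the model forward once, mirroring each numpyro.sample statement. It illustrates the stated priors and is not how NumPyro executes the model. HalfCauchy(5) is drawn as the absolute value of a Cauchy variable with scale 5, and the function name is hypothetical.

```python
import numpy as np


def simulate_eight_schools(J, sigma, rng):
    """Draw mu, tau, theta and a synthetic y from the eight-schools prior."""
    mu = rng.normal(0.0, 5.0)                   # mu ~ Normal(0, 5)
    tau = np.abs(5.0 * rng.standard_cauchy())   # tau ~ HalfCauchy(5)
    theta = rng.normal(mu, tau, size=J)         # plate 'J': one theta per school
    y = rng.normal(theta, sigma)                # obs ~ Normal(theta, sigma)
    return mu, tau, theta, y


sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
mu, tau, theta, y_sim = simulate_eight_schools(8, sigma, np.random.default_rng(0))
```

The plate plays the role that size=J and broadcasting play here: mu and tau are shared, while theta and obs have one entry per school.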

Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the use of the extra_fields argument in MCMC.run. By default, we only collect samples from the target (posterior) distribution when we run inference using MCMC. However, additional fields such as the potential energy or the acceptance probability of a sample can easily be collected through the extra_fields argument. For a list of fields that can be collected, see the HMCState object. In this example, we will additionally collect the diverging flag for each sample.

>>> from jax import random
>>> from numpyro.infer import MCMC, NUTS
>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('diverging',))

We can print the summary of the MCMC run, and examine if we observed any divergences during inference:

>>> mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      3.94      2.81      3.16      0.03      9.28    114.51      1.06
       tau      3.20      2.97      2.40      0.38      7.28     24.06      1.07
  theta[0]      5.56      5.26      4.10     -1.67     13.52     63.57      1.05
  theta[1]      4.48      4.15      3.26     -2.44     11.25    148.63      1.05
  theta[2]      3.62      4.40      3.26     -3.85     10.75    445.91      1.01
  theta[3]      4.25      4.24      3.24     -2.99     10.68    366.29      1.04
  theta[4]      3.25      3.94      3.29     -3.34      9.84    311.03      1.00
  theta[5]      3.66      4.27      2.77     -2.79     11.06    344.57      1.02
  theta[6]      5.74      4.67      4.34     -1.92     13.25     58.42      1.05
  theta[7]      4.29      4.63      3.23     -2.14     12.37    342.50      1.02

>>> print("Number of divergences: {}".format(sum(mcmc.get_extra_fields()['diverging'])))

Number of divergences: 139
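A divergent transition is one where the numerical integrator's energy error blows up. The NumPy sketch below assumes a threshold of 1000 on the energy error (a common default, for example in Stan; treat the exact value as an assumption) and integrates a 1-D Gaussian whose scale stands in for a small tau in the narrow neck of the eight-schools posterior. The function names are hypothetical.

```python
import numpy as np

MAX_DELTA_ENERGY = 1000.0  # assumed divergence threshold on the energy error


def leapfrog_energy_error(q, p, scale, step_size, num_steps):
    """Energy error after leapfrog on U(q) = q**2 / (2 * scale**2), unit mass."""
    def energy(q, p):
        return 0.5 * (q / scale) ** 2 + 0.5 * p ** 2

    start = energy(q, p)
    for _ in range(num_steps):
        p -= 0.5 * step_size * q / scale ** 2  # half step for momentum
        q += step_size * p                     # full step for position
        p -= 0.5 * step_size * q / scale ** 2  # half step for momentum
    return energy(q, p) - start


def is_divergent(energy_error):
    return energy_error > MAX_DELTA_ENERGY


# Same step size, two different posterior scales.
wide = leapfrog_energy_error(q=1.0, p=1.0, scale=1.0, step_size=0.1, num_steps=10)
narrow = leapfrog_energy_error(q=0.01, p=1.0, scale=0.01, step_size=0.1, num_steps=10)
print(is_divergent(wide), is_divergent(narrow))  # prints: False True
```

A step size that works where the posterior is wide becomes unstable where it is narrow. In the centered model, the scale of theta depends on tau, so no single step size suits the whole posterior, which is why divergences pile up.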

Values of the split Gelman-Rubin diagnostic (r_hat) above 1 indicate that the chain has not fully converged. The low effective sample size (n_eff), particularly for tau, and the number of divergent transitions also look problematic. Fortunately, this is a common pathology that can be rectified by using a non-centered parameterization for theta in our model. This is straightforward to do in NumPyro by using a TransformedDistribution instance. Let us rewrite the same model, but instead of sampling theta from a Normal(mu, tau), we will sample it from a base Normal(0, 1) distribution that is transformed using an AffineTransform. By doing so, NumPyro runs HMC by generating samples for the base Normal(0, 1) distribution instead. We see that the resulting chain does not suffer from the same pathology: the Gelman-Rubin diagnostic is 1 for all the parameters, and the effective sample size looks quite good!
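NumPyro computes r_hat internally. The sketch below is a standalone NumPy version of the textbook split-R-hat formula, not NumPyro's code, applied to synthetic chains. Each chain is split in half, and the between-half variance B is compared with the mean within-half variance W.

```python
import numpy as np


def split_rhat(chains):
    """Split-R-hat for draws of shape (num_chains, num_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split each chain in two so that drift within a chain also inflates r_hat.
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()              # W
    between = half * halves.mean(axis=1).var(ddof=1)         # B
    var_plus = (half - 1) / half * within + between / half   # pooled variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
stationary = rng.normal(size=(4, 2000))
drifting = stationary + np.linspace(0.0, 5.0, 2000)  # every chain trends upward
print(split_rhat(stationary), split_rhat(drifting))
```

All four drifting chains share the same trend, so comparing whole chains alone would not reveal the problem. Splitting each chain in half does, which is the point of the "split" variant.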

>>> # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         theta = numpyro.sample('theta', 
...                                dist.TransformedDistribution(dist.Normal(0., 1.),
...                                                             dist.transforms.AffineTransform(mu, tau)))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
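The reparameterization leaves the model unchanged in distribution. If z ~ Normal(0, 1), then mu + tau * z ~ Normal(mu, tau), which is what applying AffineTransform(mu, tau) to the base Normal(0, 1) expresses. The NumPy sketch below checks this by simulation and shows why it helps: the spread of theta shrinks with tau, while the spread of z, the variable HMC actually explores, does not. The function name and the fixed value mu = 3 are illustrative.

```python
import numpy as np


def noncentered_theta(mu, tau, size, rng):
    """Draw theta = mu + tau * z with z ~ Normal(0, 1); return (theta, z)."""
    z = rng.normal(0.0, 1.0, size=size)  # base Normal(0, 1) draws
    return mu + tau * z, z               # affine map with loc=mu, scale=tau


rng = np.random.default_rng(0)
for tau in (0.01, 1.0, 10.0):
    theta_nc, z = noncentered_theta(3.0, tau, 100_000, rng)
    theta_c = rng.normal(3.0, tau, size=100_000)  # centered draws for comparison
    print(f"tau={tau:>5}: std(theta) non-centered={theta_nc.std():.3f}, "
          f"centered={theta_c.std():.3f}, std(z)={z.std():.3f}")
```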

>>> nuts_kernel = NUTS(eight_schools_noncentered)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('diverging',))
>>> mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.38      3.04      4.50     -0.92      9.05    876.02      1.00
       tau      3.36      2.89      2.63      0.01      7.56    755.65      1.00
  theta[0]      5.99      5.42      5.44     -1.33     15.13    825.18      1.00
  theta[1]      4.80      4.50      4.78     -1.63     13.01   1114.97      1.00
  theta[2]      3.94      4.63      4.23     -3.41     11.06    914.68      1.00
  theta[3]      4.76      4.62      4.73     -2.31     12.11    958.40      1.00
  theta[4]      3.62      4.66      3.75     -3.87     11.17   1091.53      1.00
  theta[5]      3.92      4.43      4.06     -2.41     11.09   1179.74      1.00
  theta[6]      5.88      4.84      5.34     -1.45     13.11    881.38      1.00
  theta[7]      4.63      4.86      4.64     -3.57     11.80   1065.27      1.00
  
>>> print("Number of divergences: {}".format(sum(mcmc.get_extra_fields()['diverging'])))

Number of divergences: 0

Now, let us assume that we have a new school for which we have not observed any test scores, but for which we would like to generate predictions. NumPyro provides a Predictive class for this purpose. Note that in the absence of any observed data, we simply use the population-level parameters to generate predictions. The Predictive utility conditions the unobserved mu and tau sites on values drawn from the posterior distribution from our previous MCMC run, and runs the model forward to generate predictions.

>>> # New School
... def new_school():
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     return numpyro.sample('obs', dist.Normal(mu, tau))


>>> predictive = Predictive(new_school, mcmc.get_samples())
>>> samples_predictive = predictive.get_samples(random.PRNGKey(1))
>>> print(np.mean(samples_predictive['obs']))

4.419043
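Conceptually, generating these predictions amounts to drawing one obs ~ Normal(mu_s, tau_s) for every posterior sample s. The NumPy sketch below does exactly that. It uses synthetic stand-in arrays rather than the real mcmc.get_samples() output, and the function name is hypothetical; it illustrates the idea rather than Predictive's implementation.

```python
import numpy as np


def predict_new_school(posterior_mu, posterior_tau, rng):
    """One predicted effect per posterior draw of (mu, tau)."""
    posterior_mu = np.asarray(posterior_mu, dtype=float)
    posterior_tau = np.asarray(posterior_tau, dtype=float)
    # Condition on each posterior draw, then run the new_school model forward:
    # obs ~ Normal(mu_s, tau_s) for every draw s.
    return rng.normal(posterior_mu, posterior_tau)


# Synthetic stand-ins for posterior samples of mu and tau (not real results).
rng = np.random.default_rng(1)
fake_mu = rng.normal(4.0, 3.0, size=1000)
fake_tau = np.abs(rng.normal(3.0, 1.0, size=1000))
predictions = predict_new_school(fake_mu, fake_tau, rng)
print(predictions.shape, predictions.mean())
```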

Installation

Limited Windows Support: Note that NumPyro is untested on Windows, and will require building jaxlib from source. See this JAX issue for more details.

To install NumPyro with a CPU version of JAX, you can use pip:

pip install numpyro

To use NumPyro on the GPU, you will need to first install jax and jaxlib with CUDA support.

You can also install NumPyro from source:

git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev]

Examples

For some examples on specifying models and doing inference in NumPyro:

Users will note that the API for model specification is largely the same as Pyro's, including the distributions API, by design. The interface for inference algorithms and other utility functions might deviate from Pyro in favor of a more functional style that works better with JAX. For example, there is no global parameter store or random state.
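Because there is no global random state, every random draw in JAX takes an explicit key, and fresh keys are derived with jax.random.split. A minimal sketch of that pattern, using JAX's own PRNG API (which NumPyro builds on) rather than any NumPyro-specific function:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# The same key always produces the same draws: there is no hidden state to advance.
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# For fresh randomness, split the key and use each subkey once.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))
```

This is why the examples above pass rng_key explicitly to mcmc.run and a fresh random.PRNGKey(1) to Predictive.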

Future Work

In the near term, we plan to work on the following. Please open new issues for feature requests and enhancements:

  • Improving robustness of inference on different models, profiling and performance tuning.
  • More inference algorithms, particularly those that require second-order derivatives or use HMC.
  • Integration with Funsor to support inference algorithms with delayed sampling.
  • Supporting more distributions, extending the distributions API, and adding more samplers to JAX.
  • Other areas motivated by Pyro's research goals and application focus, and interest from the community.

License: MIT License

