mcgrady20150318 / numpyro

Pyro on Numpy



NumPyro


Pyro on NumPy. NumPyro uses JAX for automatic differentiation (autograd) and JIT compilation. This is an early-stage, experimental library under active development; expect many changes to the API and internal classes as the design evolves.

Design Goals

  • Lightweight - We do not intend to reimplement any of Pyro's heavy inference machinery; instead, we want to provide a flexible substrate that others can build upon. We will support Pyro primitives such as sample and param, whose side effects can be interpreted by effect handlers. Users should be able to extend this to implement custom inference algorithms and to write their models using the familiar NumPy API.
  • Functional - The API for the inference algorithms and other utility functions may deviate from Pyro's in favor of a more functional style that works better with JAX, e.g. there is no global param store or global random state.
  • Fast - Using JAX, we aim to aggressively JIT-compile intermediate computations into XLA-optimized kernels. We will evaluate the benefits of JIT compilation and benchmark the runtime of Hamiltonian Monte Carlo (HMC).
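The Functional goal means that parameters and randomness are passed in and returned explicitly instead of living in global state. The sketch below illustrates that style on a toy problem, fitting the mean of a unit-variance Gaussian by gradient ascent; init_params, grad_log_likelihood, step and fit are hypothetical helpers, and the gradient is derived by hand where JAX code would typically use jax.grad.

```python
import random


def init_params(seed):
    # Randomness comes from an explicit seed argument, not global state.
    rng = random.Random(seed)
    return {"loc": rng.uniform(-1.0, 1.0)}


def log_likelihood(params, data):
    # Unit-variance Gaussian log-likelihood, up to an additive constant.
    return sum(-0.5 * (x - params["loc"]) ** 2 for x in data)


def grad_log_likelihood(params, data):
    # Hand-derived gradient; with JAX this would be jax.grad(log_likelihood).
    return {"loc": sum(x - params["loc"] for x in data)}


def step(params, data, lr):
    # Pure update: returns new params and leaves the input dict untouched.
    grads = grad_log_likelihood(params, data)
    return {k: v + lr * grads[k] for k, v in params.items()}


def fit(params, data, lr=0.1, num_steps=100):
    for _ in range(num_steps):
        params = step(params, data, lr)
    return params


if __name__ == "__main__":
    data = [1.0, 2.0, 3.0]
    p0 = init_params(seed=0)
    p = fit(p0, data)
    print(round(p["loc"], 6))  # 2.0
```

Because nothing is mutated in place, the same seed always yields the same initial parameters, and functions like step can be transformed (differentiated, JIT-compiled, vectorized) without hidden side effects.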
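For the Fast goal, a typical hot spot worth JIT-compiling is the inner loop of Hamiltonian Monte Carlo. Below is a plain-Python sketch of one HMC transition using a leapfrog integrator, targeting a one-dimensional standard normal; the target, step size and number of leapfrog steps are arbitrary choices for illustration, and this is not NumPyro's HMC implementation.

```python
import math
import random


def potential(q):
    # Negative log density of a standard normal, up to a constant.
    return 0.5 * q * q


def grad_potential(q):
    return q


def leapfrog(q, p, step_size, num_steps):
    # Half momentum step, alternating full steps, final half momentum step.
    p = p - 0.5 * step_size * grad_potential(q)
    for i in range(num_steps):
        q = q + step_size * p
        if i != num_steps - 1:
            p = p - step_size * grad_potential(q)
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p


def hmc_step(q, rng, step_size=0.1, num_steps=10):
    p0 = rng.gauss(0.0, 1.0)  # resample momentum
    q_new, p_new = leapfrog(q, p0, step_size, num_steps)
    h0 = potential(q) + 0.5 * p0 * p0
    h1 = potential(q_new) + 0.5 * p_new * p_new
    # Metropolis correction for the integrator's energy error.
    if rng.random() < math.exp(min(0.0, h0 - h1)):
        return q_new
    return q


if __name__ == "__main__":
    rng = random.Random(0)
    q, samples = 0.0, []
    for _ in range(5000):
        q = hmc_step(q, rng)
        samples.append(q)
    print(sum(samples) / len(samples))  # should be close to 0
```

In JAX, grad_potential would typically come from jax.grad(potential), the leapfrog loop would be written with jax.lax.fori_loop, and the whole transition wrapped in jax.jit so XLA can compile it once and reuse it on every iteration; that compiled step is what an HMC runtime benchmark would exercise.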

Longer-term Plans

Much of this code may eventually be absorbed into the Pyro project itself as an alternative NumPy backend.


License: MIT License


Languages: Python 82.9%, Jupyter Notebook 17.0%, Makefile 0.1%