This repository contains a small JAX implementation of Bernstein polynomials as normalizing flows (see the original publication here). They are implemented as a combination of distrax and TensorFlow Probability (JAX substrate) objects. Flax is then used to build a simple probabilistic regression model that can fit complex distributions.
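To make the idea concrete, here is a minimal sketch of a Bernstein polynomial flow written in plain `jax.numpy`. It does not use distrax or TFP and does not reproduce the repository's API. It assumes the common formulation: the target `y` has already been rescaled to [0, 1], the transformation is h(y) = Σₖ θₖ B_{k,M}(y) built from Bernstein basis polynomials, the coefficients θ are kept strictly increasing (which makes h monotone and therefore invertible), and the base distribution is a standard normal. All function names here (`bernstein_basis`, `monotone_coefficients`, `bernstein_transform`, `bernstein_log_det`, `log_prob`) are hypothetical and chosen for illustration.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def bernstein_basis(y, degree):
    """Evaluate B_{k,M}(y) = C(M, k) y^k (1 - y)^(M - k) for k = 0..M.

    Assumes y lies in [0, 1]; returns an array of shape y.shape + (M + 1,).
    """
    k = jnp.arange(degree + 1)
    binom = jnp.array([math.comb(degree, i) for i in range(degree + 1)], dtype=jnp.float32)
    y = y[..., None]
    return binom * y**k * (1.0 - y) ** (degree - k)


def monotone_coefficients(unconstrained):
    """Map unconstrained parameters to strictly increasing theta_0 < ... < theta_M.

    theta_0 is taken as-is; each later coefficient adds a softplus (> 0) increment.
    """
    first = unconstrained[..., :1]
    increments = jax.nn.softplus(unconstrained[..., 1:])
    return jnp.concatenate([first, first + jnp.cumsum(increments, axis=-1)], axis=-1)


def bernstein_transform(y, theta):
    """h(y) = sum_k theta_k B_{k,M}(y), with M = len(theta) - 1."""
    degree = theta.shape[-1] - 1
    return jnp.sum(bernstein_basis(y, degree) * theta, axis=-1)


def bernstein_log_det(y, theta):
    """log h'(y), using h'(y) = M * sum_k (theta_{k+1} - theta_k) B_{k,M-1}(y)."""
    degree = theta.shape[-1] - 1
    dtheta = jnp.diff(theta, axis=-1)
    deriv = degree * jnp.sum(bernstein_basis(y, degree - 1) * dtheta, axis=-1)
    return jnp.log(deriv)


def log_prob(y, unconstrained):
    """Change of variables: log p(y) = log N(h(y); 0, 1) + log h'(y)."""
    theta = monotone_coefficients(unconstrained)
    z = bernstein_transform(y, theta)
    return norm.logpdf(z) + bernstein_log_det(y, theta)


# Usage: a degree-5 flow (6 coefficients) evaluated on a few points in (0, 1).
params = jnp.array([-3.0, 0.5, 0.5, 0.5, 0.5, 0.5])
ys = jnp.linspace(0.01, 0.99, 5)
print(log_prob(ys, params))
```

The softplus-plus-cumsum reparameterization is one common way to guarantee increasing coefficients so that any real-valued parameter vector yields a valid, invertible flow. In a regression setting, a network (for example a Flax module) would presumably output `unconstrained` as a function of the input features, and training would minimize the negative of `log_prob`; that wiring is an assumption here, not a description of this repository's code.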