google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/



Seeking guidance on a landing spot for `scipy.stats.levy_stable` in JAX

tjhunter opened this issue

Hello,
Thank you again for releasing this excellent framework. I have implemented the Lévy alpha-stable distribution in JAX (levy-stable-jax), and I would like some guidance and alignment before opening a potential PR:

  • Would this implementation (or part of it) fit in JAX?
  • If not, what is the best way to add a reference to it in the documentation?

The alpha-stable distribution is one of the standard 1-dimensional distributions available in scipy (link), and it has many appealing properties for modeling heavy-tailed data such as stock-market returns. It is currently not available in JAX. Because it is challenging to implement correctly, implementations generally come in two flavors:

  • Exact, quadrature-based code, such as in scipy or R: slow and hard to differentiate.
  • Approximate, interpolation-based code, such as in levy-stable-jax: very fast and easy to differentiate, but it must ship tabulated values (e.g., pylevy, Nolan's reference STABLE program).
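To make the interpolation-based flavor concrete, here is a minimal sketch, assuming a 1-D table of the standardized log-pdf for one fixed (alpha, beta) pair and plain linear interpolation. This is not levy-stable-jax's API or its actual scheme: the names `Z_GRID`, `LOGPDF_TABLE` and `interp_logpdf` are hypothetical, and the table is filled from the closed-form Cauchy log-pdf (the alpha=1, beta=0 stable law) as a stand-in for real precomputed values.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a precomputed table: the standardized log-pdf of
# the alpha=1, beta=0 stable law (the Cauchy distribution) on a uniform grid.
Z_GRID = jnp.linspace(-20.0, 20.0, 2001)
LOGPDF_TABLE = -jnp.log(jnp.pi) - jnp.log1p(Z_GRID**2)


def interp_logpdf(x, loc, scale):
    """Approximate log-pdf by linear interpolation of the tabulated values.

    Broadcasts over x, loc and scale. jnp.interp clamps to the end values
    outside the grid; a real implementation would use a tail approximation.
    """
    z = (x - loc) / scale
    # Change of variables: log f(x) = log g(z) - log(scale).
    return jnp.interp(z, Z_GRID, LOGPDF_TABLE) - jnp.log(scale)


# Gradient with respect to scale (argument index 2), obtained via jax.grad.
dlogpdf_dscale = jax.grad(interp_logpdf, argnums=2)
```

The evaluation is a table lookup plus arithmetic, so it is cheap and `jax.grad` flows through it; the cost moves into the size and accuracy of the table, which in the full 4-parameter case must also be indexed by alpha and beta.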

levy-stable-jax's implementation has a thorough test suite, and it is vectorized and differentiable with respect to all 4 parameters of the distribution. It is especially tuned for fast maximum likelihood estimation, which fits JAX's general audience well. However, because of the tabulated values, the package is at least 5 MB, which is probably too large to add to JAX itself. It could probably be reduced to about 500 kB, but I would only consider this extra work if there is a reasonable likelihood that this distribution would land in JAX.
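The maximum likelihood workflow this targets can be sketched as follows, assuming a closed-form Cauchy log-pdf (alpha=1, beta=0) as a stand-in for a stable log-pdf, so only location and log-scale are fitted. The helper names (`cauchy_logpdf`, `neg_mean_log_likelihood`, `step`, `fit`) and the plain gradient-descent optimizer are illustrative choices, not levy-stable-jax's API.

```python
import jax
import jax.numpy as jnp


def cauchy_logpdf(x, loc, log_scale):
    # Closed-form stand-in for a stable log-pdf with alpha=1, beta=0.
    scale = jnp.exp(log_scale)
    z = (x - loc) / scale
    return -jnp.log(jnp.pi) - jnp.log1p(z**2) - log_scale


def neg_mean_log_likelihood(params, data):
    # Vectorized over the whole dataset; params = [loc, log_scale].
    return -jnp.mean(cauchy_logpdf(data, params[0], params[1]))


@jax.jit
def step(params, data, lr):
    # One gradient-descent step on the negative mean log-likelihood.
    grads = jax.grad(neg_mean_log_likelihood)(params, data)
    return params - lr * grads


def fit(data, n_steps=500, lr=0.5):
    # Start from the sample median for loc and scale = 1.
    params = jnp.array([jnp.median(data), 0.0])
    for _ in range(n_steps):
        params = step(params, data, lr)
    return params
```

For example, `fit(3.0 + 2.0 * jax.random.cauchy(jax.random.PRNGKey(0), (5000,)))` should return a location near 3 and a log-scale near log(2). Optimizing the log-scale keeps the scale positive without constraints; a real fit over all 4 parameters would need similar reparameterizations for alpha and beta.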

In general, I do not think an exact-algorithm version of levy_stable is likely to come to JAX: all known formulas for the pdf rely on truncating infinite series or on numerical quadrature. Any implementation would at the very least make JAX depend on quadax, and whether it could be differentiated with respect to all parameters for all values would remain an open question.
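To illustrate why the exact route is costly, here is a simplified sketch restricted to the symmetric case (beta=0, loc=0), where the characteristic function is exp(-|scale·t|^alpha) and the pdf is (1/π)∫₀^∞ cos(t·x)·exp(-(scale·t)^alpha) dt. This direct Fourier-inversion approach with a fixed midpoint rule is a swapped-in technique for illustration, not the algorithm used by scipy, R or quadax; `symmetric_stable_pdf`, `t_max` and `n` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def symmetric_stable_pdf(x, alpha, scale=1.0, t_max=40.0, n=8000):
    """Approximate pdf of a symmetric (beta=0, loc=0) alpha-stable law.

    Truncates the inversion integral at t_max and evaluates it with an
    n-point midpoint rule.
    """
    x = jnp.asarray(x)
    dt = t_max / n
    # Midpoints avoid t = 0, where d/dalpha of t**alpha would give NaN.
    t = (jnp.arange(n) + 0.5) * dt
    integrand = jnp.cos(t * x[..., None]) * jnp.exp(-((scale * t) ** alpha))
    return jnp.sum(integrand, axis=-1) * dt / jnp.pi
```

It reproduces the Cauchy (alpha=1) and Gaussian (alpha=2) special cases, but every evaluation costs thousands of integrand terms, and fixed values of `t_max` and `n` stop being adequate for small alpha (slowly decaying integrand) or large |x| (fast oscillation). The skewed case adds further terms, and keeping derivatives accurate across all of these regimes is exactly the open question raised above.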