Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/

Optimistix

Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares.

Features include:

  • interoperable solvers: e.g. autoconvert root find problems to least squares problems, then solve using a minimisation algorithm.
  • modular optimisers: e.g. use a BFGS quadratic bowl with a dogleg descent path with a trust region update.
  • using a PyTree as the state.
  • fast compilation and runtimes.
  • interoperability with Optax.
  • all the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.

Installation

pip install optimistix

Requires Python 3.9+, JAX 0.4.14+, and Equinox 0.11.0+.

Documentation

Available at https://docs.kidger.site/optimistix.

Quick example

import jax.numpy as jnp
import optimistix as optx

# Let's solve the ODE dy/dt=tanh(y(t)) with the implicit Euler method.
# We need to find y1 s.t. y1 = y0 + tanh(y1)dt.

y0 = jnp.array(1.)
dt = jnp.array(0.1)

def fn(y, args):
    return y0 + jnp.tanh(y) * dt

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.fixed_point(fn, solver, y0)
y1 = sol.value  # satisfies y1 == fn(y1)

Citation

If you found this library useful in academic work, please cite it (arXiv:2402.09983):

@article{optimistix2024,
    title={Optimistix: modular optimisation in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={arXiv:2402.09983},
    year={2024},
}

Finally

JAX ecosystem

jaxtyping: type annotations for shape/dtype of arrays.

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Disclaimer

This is not an official Google product.

Credit

Optimistix was primarily built by Jason Rader (@packquickly).

License: Apache License 2.0

