google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Expose optimization path of L-BFGS optimizer

miclegr opened this issue

Algorithms such as Pathfinder (paper, blackjax-devs/blackjax#157) need the entire optimization path of an L-BFGS run as input, not only the final result.
Unlike R (and consistently with SciPy), jax's jax._src.scipy.optimize._lbfgs._minimize_lbfgs does not expose the path, only the result. A generalization of this function that exposes the optimization path would look like this, and it would make it possible to implement Pathfinder without rewriting a jax internal function and keeping that copy in sync with the jax main branch over time.
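To make the request concrete, here is a minimal pure-Python sketch of an L-BFGS loop that returns every iterate and gradient instead of only the final point. It is not jax's implementation: `lbfgs_with_path` is a hypothetical name, and it uses the textbook two-loop recursion with a simple backtracking (Armijo) line search rather than the line search jax uses.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def axpy(alpha: float, x: Sequence[float], y: Sequence[float]) -> Vector:
    # Returns alpha * x + y.
    return [alpha * xi + yi for xi, yi in zip(x, y)]


def lbfgs_with_path(
    f: Callable[[Vector], float],
    grad: Callable[[Vector], Vector],
    x0: Sequence[float],
    m: int = 5,
    maxiter: int = 100,
    tol: float = 1e-6,
) -> Tuple[List[Vector], List[Vector]]:
    """Hypothetical L-BFGS that records the whole path (iterates and gradients)."""
    x = list(x0)
    g = grad(x)
    xs, gs = [list(x)], [list(g)]
    s_hist: List[Vector] = []
    y_hist: List[Vector] = []
    for _ in range(maxiter):
        if math.sqrt(dot(g, g)) < tol:
            break
        # Two-loop recursion: r approximates H^{-1} g using the last m pairs.
        q = list(g)
        alphas = []  # (rho, alpha) pairs, newest first
        for s, y in reversed(list(zip(s_hist, y_hist))):
            rho = 1.0 / dot(y, s)
            a = rho * dot(s, q)
            alphas.append((rho, a))
            q = axpy(-a, y, q)
        gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1]) if s_hist else 1.0
        r = [gamma * qi for qi in q]
        for (s, y), (rho, a) in zip(zip(s_hist, y_hist), reversed(alphas)):
            b = rho * dot(y, r)
            r = axpy(a - b, s, r)
        d = [-ri for ri in r]  # descent direction
        # Backtracking line search enforcing the Armijo condition.
        fx, slope, t = f(x), dot(g, d), 1.0
        while f(axpy(t, d, x)) > fx + 1e-4 * t * slope and t > 1e-10:
            t *= 0.5
        x_new = axpy(t, d, x)
        g_new = grad(x_new)
        s = [a - b for a, b in zip(x_new, x)]
        y = [a - b for a, b in zip(g_new, g)]
        # Keep only pairs with positive curvature so the inverse Hessian stays positive definite.
        if dot(s, y) > 1e-12:
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > m:
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
        xs.append(list(x))
        gs.append(list(g))
    return xs, gs


def f(x: Vector) -> float:
    return (x[0] - 1.0) ** 2 + 10.0 * (x[1] + 2.0) ** 2


def grad_f(x: Vector) -> Vector:
    return [2.0 * (x[0] - 1.0), 20.0 * (x[1] + 2.0)]


xs, gs = lbfgs_with_path(f, grad_f, [0.0, 0.0])
```

The caller gets the full trajectory (`xs`, `gs`), and the final iterate is simply `xs[-1]`, so a result-only API can be layered on top of a path-returning one but not the other way around.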

Happy to submit a proper pull request if such a feature would be considered for merging.

We are looking to deprecate jax.scipy.optimize.minimize and recommend JAXopt instead, so I'd like to bring @mblondel into this thread (and perhaps we should eventually move the discussion to the JAXopt repository).

JAXopt lets you do this by calling solver.update repeatedly in your own for loop instead of calling solver.run.

Thanks! That solves the problem, appreciated 👍