google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

Home Page: https://jaxopt.github.io

LBFGSB produces NaN under certain conditions

jjyyxx opened this issue.

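In the following snippet, the returned value of b is NaN, even though b starts inside its bounds and does not appear in the objective.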

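import jax.numpy as jnp
import jaxopt

def fun(x):
    # The objective depends only on a, so the gradient with respect to b is zero.
    a, b = x
    return -a

solver = jaxopt.LBFGSB(fun)

# a starts at its upper bound, which already minimizes -a on [-1, 1].
init = jnp.array([1.0, 0.0])
upper = jnp.array([1.0, 1.0])
lower = jnp.array([-1.0, -1.0])
result = solver.run(init, bounds=(lower, upper))
a, b = result.params
print(a, b)  # reported behavior: b is NaN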
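
The starting point in this report is already optimal for the bounded problem: a sits at its upper bound, and the gradient with respect to b is zero. The standard projected-gradient check for box constraints makes this concrete. The sketch below is plain Python with the gradient of fun written by hand rather than computed by JAX. The helper names (grad_fun, clip, projected_gradient_residual) are hypothetical and are not part of jaxopt.

```python
# Hypothetical helpers (not part of jaxopt): first-order optimality check for a
# box-constrained problem via the projected-gradient residual
#   r(x) = clip(x - g(x), lower, upper) - x,
# which is zero exactly when x is stationary under the bounds.

def grad_fun(x):
    # Hand-written gradient of fun(x) = -a with respect to (a, b); b does not appear.
    return [-1.0, 0.0]

def clip(v, lo, hi):
    # Project a scalar onto the interval [lo, hi].
    return min(max(v, lo), hi)

def projected_gradient_residual(x, lower, upper):
    g = grad_fun(x)
    return [clip(xi - gi, lo, hi) - xi
            for xi, gi, lo, hi in zip(x, g, lower, upper)]

init = [1.0, 0.0]
lower = [-1.0, -1.0]
upper = [1.0, 1.0]

r = projected_gradient_residual(init, lower, upper)
print(r)  # [0.0, 0.0]: the starting point is already stationary
```

A zero residual means there is no feasible descent direction, so a solver's first step from this point is degenerate. That such a degenerate start is what triggers the NaN is an assumption; the report itself does not state the cause.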

Fixed by @vroulet in #493!