LBFGSB produces NaN under certain conditions
jjyyxx opened this issue
`b` is NaN in the following snippet:
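```python
import jax.numpy as jnp
import jaxopt

def fun(x):
    a, b = x
    # The objective depends only on `a`; its gradient w.r.t. `b` is always 0.
    return -a

solver = jaxopt.LBFGSB(fun)
# `a` starts at its upper bound (already optimal), `b` starts at 0.
init = jnp.array([1.0, 0.0])
upper = jnp.array([1.0, 1.0])
lower = jnp.array([-1.0, -1.0])
result = solver.run(init, bounds=(lower, upper))
a, b = result.params
print(a, b)  # `b` comes out as NaN
```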
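To see what makes this input unusual, the check below (not part of the original report) computes the gradient at `init` and the projected-gradient measure often used as a stopping test for box-constrained methods. `a` already sits at its upper bound with a gradient pushing it further up, and `b` has zero gradient, so the projected gradient is exactly zero: the starting point already satisfies the first-order optimality conditions. That this degenerate start is what triggers the NaN inside `jaxopt.LBFGSB` is an assumption, not something the report confirms. The names `g` and `proj_grad` are illustrative.

```python
import jax
import jax.numpy as jnp

def fun(x):
    a, b = x
    return -a

init = jnp.array([1.0, 0.0])
lower = jnp.array([-1.0, -1.0])
upper = jnp.array([1.0, 1.0])

# Gradient at the start: -1 for `a`, 0 for `b`.
g = jax.grad(fun)(init)

# Projected gradient: take a unit gradient step, clip it back into the box,
# and measure how far the point actually moved. Zero means no feasible descent.
proj_grad = init - jnp.clip(init - g, lower, upper)
print(g, proj_grad)
```

If `proj_grad` is all zeros, a solver would ideally stop immediately and return `init` unchanged; a NaN instead suggests some internal quantity (for example a curvature term built from a zero step) is being divided by zero, though that too is a hypothesis.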