Wrong failure diagnostic printouts from `ZoomLineSearch` under `vmap`
tare opened this issue
Environment
```
% pip list | grep jax
jax       0.4.20
jaxlib    0.4.20
jaxopt    0.8.2
% python --version
Python 3.10.11
```
Description
`ZoomLineSearch` under `vmap` ends up calling `failure_diagnostic()` even when `safe_stepsize > 0`, as shown here. This can result in a lot of printouts, and I didn't see a way to disable the failure diagnostic printouts given the current implementation. I think the relevant commit is 614dc7b. Below, you will find minimal reproducible examples.
The following code
```python
import jax.numpy as jnp
from jax import vmap
from jaxopt import LBFGS

def solve(x, y):
    solver = LBFGS(lambda x, y: jnp.square(y - x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)
# Batch over ys only; the initial point is shared.
vmap(solve, in_axes=(None, 0))(x_init, ys)
```
gives the following warnings
```
WARNING: jaxopt.ZoomLineSearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
INFO: jaxopt.ZoomLineSearch: Iter: 1, Stepsize: 1.0, Decrease error: -0.0, Curvature error: 0.0
WARNING: jaxopt.ZoomLineSearch: The linesearch failed because the provided direction is not a descent direction. The slope (=-0.0) at stepsize=0 should be negative
WARNING: jaxopt.ZoomLineSearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: jaxopt.ZoomLineSearch: Computed stepsize (=1.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.0). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
WARNING: jaxopt.ZoomLineSearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
```
Whereas the following code does not produce any warnings:
```python
import jax.numpy as jnp
from jax import jit
from jax.lax import map
from jaxopt import LBFGS

def solve(x, y):
    solver = LBFGS(lambda x, y: jnp.square(y - x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)
# lax.map runs the jitted solver sequentially, one y at a time.
res = map(jit(lambda y: solve(x_init, y)), ys)
```
Here is a minimal reproducible example illustrating the issue with `jax.debug.print`, `cond`, and `vmap`; the following code
```python
import jax.debug
import jax.numpy as jnp
from jax import jit, vmap
from jax.lax import cond, map

def test(x):
    def true_fun(x):
        pass

    def false_fun(x):
        # Should only print for x >= 3.
        jax.debug.print("{}", x)

    cond(x < 3, true_fun, false_fun, x)

print("map and jit")
map(jit(test), jnp.arange(5))
print("vmap")
vmap(test)(jnp.arange(5))
```
gives the following output
```
map and jit
3
4
vmap
0
1
2
3
4
```
Hello @tare,
Thanks for pointing this out. When the predicate is batched, `vmap` evaluates both branches of a `cond` (this is not the case without `vmap`); see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html.
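To see this `vmap` behaviour directly, the sketch below compares the jaxprs of an unbatched and a batched `lax.cond` on a toy function `f` (not jaxopt code). Assuming current JAX batching rules, the unbatched trace keeps a real `cond` primitive, while the batched trace computes both branches and combines them with `select_n`, which is why the `false_fun` print in the example above fires for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Toy function: one branch adds, the other multiplies.
    return lax.cond(x < 3, lambda v: v + 1.0, lambda v: v * 2.0, x)


xs = jnp.arange(5.0)

# Scalar predicate: the trace keeps a real `cond`, so only one branch runs.
unbatched = str(jax.make_jaxpr(f)(1.0))

# Batched predicate: vmap rewrites the cond into a `select_n` over the
# results of both branches, so both branches are evaluated.
batched = str(jax.make_jaxpr(jax.vmap(f))(xs))

print("cond[" in unbatched, "select_n" in batched, "cond[" in batched)
print(jax.vmap(f)(xs))  # [1. 2. 3. 6. 8.]
```

Numerically the result is the same either way; the difference only becomes visible through side effects such as prints.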
I'm not sure how we could then have failure diagnostics under vmap.
At least, I have patched #544 so that zoom does not display failure diagnostics unless `verbose` is set to `True`. That will avoid unnecessary prints.
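Gating on a `verbose` flag helps because a Python-level boolean is resolved at trace time: when it is `False`, no debug callback is ever staged into the computation, so `vmap` has nothing to print in either branch. The toy `hypothetical_step` below illustrates that general pattern; it is an assumption-labelled sketch, not the code of #544.

```python
import functools

import jax
import jax.numpy as jnp


def hypothetical_step(x, verbose=False):
    """Toy update with an optional diagnostic (hypothetical, not jaxopt code)."""
    new_x = 0.5 * x
    if verbose:
        # A Python bool is checked while tracing: when it is False, no
        # debug callback is ever added to the computation.
        jax.debug.print("new_x = {}", new_x)
    return new_x


xs = jnp.arange(4.0)
quiet = str(jax.make_jaxpr(jax.vmap(hypothetical_step))(xs))
loud = str(
    jax.make_jaxpr(jax.vmap(functools.partial(hypothetical_step, verbose=True)))(xs)
)

print("debug_callback" in quiet, "debug_callback" in loud)  # False True
```

The flag is bound with `functools.partial` rather than passed as a keyword to the vmapped function, because `vmap` maps over keyword arguments.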
Thanks for the quick reply and pointing out #544! I hope that PR gets merged soon.