google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

Home page: https://jaxopt.github.io


Wrong failure diagnostic print outs from `ZoomLineSearch` under `vmap`

tare opened this issue

Environment

```
% pip list | grep jax
jax                       0.4.20
jaxlib                    0.4.20
jaxopt                    0.8.2

% python --version
Python 3.10.11
```

Description

`ZoomLineSearch` under `vmap` ends up calling `failure_diagnostic()` even when `safe_stepsize > 0`, as shown here. This can result in a large number of printouts, and I didn't see a way to disable the failure diagnostic printouts in the current implementation. I think the relevant commit is 614dc7b. Below are minimal reproducible examples.

The following code

```python
import jax.numpy as jnp
from jax import vmap
from jaxopt import LBFGS

def solve(x, y):
    # Scalar objective, minimized at x == y.
    solver = LBFGS(lambda x, y: jnp.square(y - x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)

# Batch over ys only; the initial point is shared across the batch.
vmap(solve, in_axes=(None, 0))(x_init, ys)
```

gives the following warnings:

```
WARNING: jaxopt.ZoomLineSearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
INFO: jaxopt.ZoomLineSearch: Iter: 1, Stepsize: 1.0, Decrease error: -0.0, Curvature error: 0.0
WARNING: jaxopt.ZoomLineSearch: The linesearch failed because the provided direction is not a descent direction. The slope (=-0.0) at stepsize=0 should be negative
WARNING: jaxopt.ZoomLineSearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: jaxopt.ZoomLineSearch: Computed stepsize (=1.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.0). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
WARNING: jaxopt.ZoomLineSearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.
```

By contrast, the following code does not produce any warnings:

```python
import jax.numpy as jnp
from jax import jit
from jax.lax import map
from jaxopt import LBFGS

def solve(x, y):
    solver = LBFGS(lambda x, y: jnp.square(y - x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)

# jax.lax.map applies the jitted solver to each y sequentially (no batching).
res = map(jit(lambda y: solve(x_init, y)), ys)
```

Here is a minimal reproducible example illustrating the issue with jax.debug.print, cond, and vmap; the following code

```python
import jax.debug
import jax.numpy as jnp
from jax import jit, vmap
from jax.lax import cond, map  # jax.lax.map, not Python's lazy builtin map

def test(x):
    def true_fun(x):
        pass
    def false_fun(x):
        # Expected to print only for x >= 3.
        jax.debug.print("{}", x)
    cond(x < 3, true_fun, false_fun, x)

print("map and jit")
map(jit(test), jnp.arange(5))
print("vmap")
vmap(test)(jnp.arange(5))
```

gives the following output:

```
map and jit
3
4
vmap
0
1
2
3
4
```
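The difference comes from how `vmap` batches `jax.lax.cond`: per the `jax.lax.cond` documentation, when the predicate is batched, the `cond` is converted into a `select` and both branches are executed for every element. The sketch below makes this observable without reading printed output, by recording each call of the false branch through `jax.debug.callback` into a Python list. The names `record`, `calls` and `branchy` are illustrative only; this is a standalone JAX sketch, not jaxopt code.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap

calls = []  # values seen by the false branch's host-side callback

def record(v):
    # Host callback; v arrives as a 0-d array.
    calls.append(int(v))

def branchy(x):
    def true_fun(x):
        return x
    def false_fun(x):
        jax.debug.callback(record, x)  # side effect only in this branch
        return x
    return lax.cond(x < 3, true_fun, false_fun, x)

# Sequential map: the predicate is not batched, so only the taken branch runs.
calls.clear()
lax.map(branchy, jnp.arange(5))
jax.effects_barrier()  # wait for pending host callbacks to finish
map_calls = sorted(calls)

# vmap: the predicate is batched, so cond becomes a select over both branches.
calls.clear()
vmap(branchy)(jnp.arange(5))
jax.effects_barrier()
vmap_calls = sorted(calls)

print(map_calls)   # [3, 4]
print(vmap_calls)  # [0, 1, 2, 3, 4]
```

The lists are sorted because `jax.debug.callback` does not guarantee ordering between callbacks.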

Hello @tare,
Thanks for pointing this out. When the predicate of a `cond` is batched, `vmap` evaluates both branches (this is not the case without `vmap`); see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html.
I'm not sure how we could then have failure diagnostics under `vmap`.
In the meantime, I have opened #544, which patches zoom so that it does not display failure diagnostics unless `verbose` is set to `True`. That will avoid unnecessary prints.
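A minimal sketch of why gating on `verbose` helps, assuming the diagnostic is guarded by a plain Python `bool` that is fixed at trace time. Because `if verbose:` is evaluated in Python while the branch is traced, a non-verbose run never emits the callback at all, so `vmap` has nothing to execute even though it runs both branches. `step_with_diagnostic`, `report` and `messages` are hypothetical names, and the "failure" condition is a toy stand-in; this is not the code in #544.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax, vmap

messages = []  # collected diagnostics (stand-in for printed warnings)

def report(stepsize):
    messages.append(float(stepsize))

def step_with_diagnostic(stepsize, verbose):
    # Hypothetical: "failure" means no positive safe stepsize was found.
    def ok(s):
        return s
    def failure(s):
        if verbose:  # Python bool, resolved at trace time
            jax.debug.callback(report, s)
        return s
    return lax.cond(stepsize > 0, ok, failure, stepsize)

stepsizes = jnp.array([1.0, 0.5, 0.25])  # every element "succeeds"

counts = {}
for verbose in (False, True):
    messages.clear()
    f = functools.partial(step_with_diagnostic, verbose=verbose)
    vmap(f)(stepsizes)  # only stepsizes is batched; verbose stays static
    jax.effects_barrier()
    counts[verbose] = len(messages)

print(counts)  # {False: 0, True: 3}
```

With `verbose=True`, all three elements report even though none failed, which is the same spurious behavior described in the issue.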

Thanks for the quick reply and pointing out #544! I hope that PR gets merged soon.