google / autobound

AutoBound automatically computes upper and lower bounds on functions.
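For context, here is a minimal usage sketch following the example in the project README (the ab.taylor_bounds call and the sample function are taken from there; this is an illustration, not part of the issue below):

import autobound.jax as ab
import jax.numpy as jnp

# Function to bound, expansion point, and the interval over which
# the bounds should hold.
f = lambda x: 1.5 * jnp.exp(3.0 * x) - 25.0 * x**2
x0 = 0.5
trust_region = (0.0, 1.0)

# Compute quadratic upper and lower bounds on f over the trust region.
bounds = ab.taylor_bounds(f, max_degree=2)(x0, trust_region)
bounds.upper(1.0)  # an upper bound on f(1.0)
bounds.lower(0.0)  # a lower bound on f(0.0)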


Bug in Colab Notebook

jafluri opened this issue

Hi guys,

First of all, great work, really cool stuff.

I just had a look at this repo and the Colab notebook for safe learning rates. I am not sure if this is the right place to post this, since it is not really a problem with the package itself, but I think there is a bug in the saferate_optimizer of the aforementioned notebook.

def saferate_optimizer(loss, initial_max_eta: float = 1.):
  def init_fun(x0):
    return (x0, 0., initial_max_eta)

  def update_state(state):
    x, _, max_eta = state
    # Note: the negated gradient computed here is never assigned;
    # update_dir on the next line resolves from an outer scope.
    jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))
    safe_eta = safe_learning_rate(x, update_dir, max_eta)
    next_x = jax.tree_util.tree_map(
        lambda p, v: p + safe_eta*v, x, update_dir
    )
    next_max_eta = jnp.where(
        # If safe_eta is NaN, we cut the learning rate in half.
        jnp.logical_or(safe_eta < max_eta / 2, safe_eta > safe_eta),
        max_eta / 2,
        max_eta * 2
    )
    return (next_x, safe_eta, next_max_eta)

  def get_params(state):
    x, _, _ = state
    return x

  return init_fun, update_state, get_params

In the update_state function, the result of the line jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x)) is assigned to nothing; I assume it should be assigned to update_dir. Currently, update_dir is taken from an outer scope.
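For reference, a minimal sketch of the fix I have in mind, simply assigning the negated gradient to update_dir:

  def update_state(state):
    x, _, max_eta = state
    # Assign the negated gradient so update_dir no longer comes from
    # the outer scope.
    update_dir = jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))
    safe_eta = safe_learning_rate(x, update_dir, max_eta)
    # ... rest of update_state unchanged ...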

Additionally, just for readability, I would recommend adding loss as an explicit argument to the bound_loss function at the beginning of the notebook, since it also comes from an outer scope.
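To illustrate (the parameter names here are hypothetical, since bound_loss is not quoted in this issue; the point is just to make the dependency explicit):

# Before: loss is captured from the surrounding notebook scope.
def bound_loss(x, direction, max_eta):
  ...

# After: loss is passed explicitly, making the dependency visible.
def bound_loss(loss, x, direction, max_eta):
  ...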

Thanks for catching this! I fixed both issues (update_dir and the loss argument).