Bug in Colab Notebook
jafluri opened this issue
Hi guys,
First of all, great work, really cool stuff.
I just had a look at this repo and the Colab notebook for safe learning rates. I am not sure if this is the right place to post this, since it is not really a problem with the package itself, but I think there is a bug in the `saferate_optimizer` of the aforementioned notebook.
```python
def saferate_optimizer(loss, initial_max_eta: float = 1.):

    def init_fun(x0):
        return (x0, 0., initial_max_eta)

    def update_state(state):
        x, _, max_eta = state
        # Note: the result of this gradient computation is never assigned.
        jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))
        safe_eta = safe_learning_rate(x, update_dir, max_eta)
        next_x = jax.tree_util.tree_map(
            lambda p, v: p + safe_eta * v, x, update_dir
        )
        next_max_eta = jnp.where(
            # If safe_eta is NaN, we cut the learning rate in half.
            jnp.logical_or(safe_eta < max_eta / 2, safe_eta > safe_eta),
            max_eta / 2,
            max_eta * 2
        )
        return (next_x, safe_eta, next_max_eta)

    def get_params(state):
        x, _, _ = state
        return x

    return init_fun, update_state, get_params
```
In the `update_state` function, the line `jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))` is assigned to nothing; I assume its result should be assigned to `update_dir`. Currently, `update_dir` is taken from an outer scope.
Additionally, just for readability, I would recommend adding `loss` as an explicit argument to the `bound_loss` function at the beginning of the notebook, as it also comes from an outer scope.
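In the same spirit, the change would look something like this (a sketch only; the body shown is a placeholder, since the actual implementation lives in the notebook):

```python
# Before: loss is captured from an outer scope.
def bound_loss(x):
    return loss(x)  # placeholder for the notebook's actual body

# After: loss is an explicit argument, which makes the dependency visible.
def bound_loss(loss, x):
    return loss(x)  # placeholder for the notebook's actual body
```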
Thanks for catching this! I fixed both issues (`update_dir` and the `loss` argument).