smonsays / shrink-perturb

Optax implementation of shrink and perturb (Ash & Adams, 2020).
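The shrink-and-perturb rule from Ash & Adams (2020) scales each parameter toward zero and adds a freshly initialized parameter tensor scaled by a small factor. A minimal sketch on a parameter pytree (the helper name `shrink_perturb_params` and the example values are illustrative, not part of this repo's API):

```python
import jax
import jax.numpy as jnp

def shrink_perturb_params(params, init_params, shrink=0.9, perturb=0.001):
    # theta_new = shrink * theta + perturb * theta_init, applied leaf-wise
    return jax.tree_util.tree_map(
        lambda p, p0: shrink * p + perturb * p0, params, init_params
    )

# Toy example: trained params of 1.0, fresh init of 2.0
params = {"w": jnp.ones((2,))}
init = {"w": jnp.full((2,), 2.0)}
out = shrink_perturb_params(params, init, shrink=0.5, perturb=0.25)
# 0.5 * 1.0 + 0.25 * 2.0 == 1.0 for each entry
```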


Shrink-perturb

Example usage

```python
import haiku as hk
import jax
import jax.numpy as jnp
import optax

from shrink_perturb import shrink_perturb

@hk.without_apply_rng
@hk.transform
def mlp(x):
    return hk.nets.MLP([10, 1])(x)


placeholder_input = jnp.empty((16, 16))
optimizer = optax.chain(
    optax.sgd(learning_rate=0.01),
    # Simply chain `shrink_perturb` after the optimizer,
    # passing the model's init_fn closed over an input
    shrink_perturb(
        param_init_fn=lambda k: mlp.init(k, placeholder_input),
        shrink=0.9,
        perturb=0.001,
    ),
)

params = mlp.init(jax.random.PRNGKey(0), placeholder_input)
optim_state = optimizer.init(params)
grads = jax.grad(lambda p, x: jnp.sum(mlp.apply(p, x)))(
    params, jnp.ones((16, 16))
)
# `params` must be passed to optimizer.update()
updates, optim_state = optimizer.update(grads, optim_state, params)
params = optax.apply_updates(params, updates)
```

About

License: MIT License


Languages

Language: Python 100.0%