RobertTLange / evosax

Evolution Strategies in JAX 🦎

Feature request: Add convenience function for optimization

carlosgmartin opened this issue

I suggest adding a convenience function, `evosax.optimize`, that handles the most common optimization pattern in evosax: initialize a strategy, then repeatedly ask it for candidates, evaluate them, and tell it their fitness. A possible implementation (open to modification) and a usage example are shown below:

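```python
import evosax
import jax
import jax.numpy as jnp


def optimize(f, x, rng, strat, steps_per_output=1):
    """Yield (current mean params, average population fitness) every `steps_per_output` steps.

    `f(x, rng)` maps one candidate and one PRNG key to a scalar fitness.
    The generator never terminates; the caller decides when to stop.
    """

    def es_step(state, rng):
        # One key for sampling the population, plus one key per member for evaluating f.
        rngs = jax.random.split(rng, 1 + strat.popsize)
        x, state = strat.ask(rngs[0], state)
        y = jax.vmap(f)(x, rngs[1:])
        state = strat.tell(x, y, state)
        return state, y.mean()

    @jax.jit
    def es_epoch(state, rng):
        # Run `steps_per_output` ask/evaluate/tell steps inside one compiled scan.
        rngs = jax.random.split(rng, steps_per_output)
        state, y = jax.lax.scan(es_step, state, rngs)
        return state, y.mean()

    @jax.jit
    def output_fn(state):
        # Reshape the flat search-distribution mean into the structure of the placeholder params.
        return strat.param_reshaper.reshape_single(state.mean)

    rng, subrng = jax.random.split(rng)
    state = strat.initialize(subrng, init_mean=x)
    while True:
        rng, subrng = jax.random.split(rng)
        state, y = es_epoch(state, subrng)
        yield output_fn(state), y


def f(x, rng):
    # Noisy sphere objective: x @ x has its minimum of 0 at the origin.
    return x @ x + jax.random.normal(rng) * 0.1


def main():
    x = jnp.ones(2)
    rng = jax.random.PRNGKey(0)
    strat = evosax.OpenES(
        popsize=10,
        pholder_params=x,
        opt_name="adam",
        lrate_init=1e-4,
        sigma_init=1e-1,
    )
    # `optimize` is an infinite generator, so this loop runs until interrupted.
    for x, y in optimize(f, x, rng, strat):
        print(y)


if __name__ == "__main__":
    main()
```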
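For readers less familiar with the ask/tell interface, the sketch below reproduces the same generator-over-`lax.scan` structure without evosax, using a hand-rolled Gaussian-perturbation ES: an OpenAI-ES-style gradient estimate with a mean-fitness baseline, no fitness shaping, and plain SGD instead of the Adam optimizer selected above. `es_generator` and its keyword arguments are hypothetical names chosen for illustration and are not part of evosax; the sketch is meant to show roughly what each epoch computes, not to replace `strat.ask`/`strat.tell`.

```python
import itertools

import jax
import jax.numpy as jnp


def es_generator(f, x, rng, popsize=64, sigma=0.1, lrate=0.05, steps_per_output=10):
    """Hypothetical stand-in for `optimize` (not evosax API).

    Minimizes f(x, rng) -> scalar with a Gaussian-perturbation ES and plain SGD,
    yielding (mean, average population fitness) every `steps_per_output` steps.
    """

    def es_step(mean, rng):
        rng_eps, rng_eval = jax.random.split(rng)
        # "Ask": sample `popsize` Gaussian perturbations around the current mean.
        eps = jax.random.normal(rng_eps, (popsize,) + mean.shape)
        candidates = mean + sigma * eps
        # Evaluate each candidate with its own PRNG key, as `optimize` does.
        y = jax.vmap(f)(candidates, jax.random.split(rng_eval, popsize))
        # "Tell": baseline-subtracted ES estimate of the gradient of E[f(mean + sigma * eps)].
        grad = jnp.tensordot(y - y.mean(), eps, axes=1) / (popsize * sigma)
        # Minimization, so step against the estimated gradient.
        return mean - lrate * grad, y.mean()

    @jax.jit
    def es_epoch(mean, rng):
        # Fuse `steps_per_output` ES steps into one compiled scan.
        rngs = jax.random.split(rng, steps_per_output)
        mean, y = jax.lax.scan(es_step, mean, rngs)
        return mean, y.mean()

    mean = jnp.asarray(x, dtype=jnp.float32)
    while True:
        rng, subrng = jax.random.split(rng)
        mean, y = es_epoch(mean, subrng)
        yield mean, y


def f(x, rng):
    # Same noisy sphere objective as in the issue.
    return x @ x + jax.random.normal(rng) * 0.1


# The generator is infinite; islice keeps the first 50 outputs (500 ES steps).
history = list(itertools.islice(es_generator(f, jnp.ones(2), jax.random.PRNGKey(0)), 50))
final_mean, final_fitness = history[-1]
print(final_mean, final_fitness)
```

Because the generator never terminates, the caller chooses the stopping rule: `itertools.islice` caps the number of outputs here, and a `for` loop that breaks on a fitness threshold works the same way.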

What do you think? I can submit a PR if desired.