Feature request: Add convenience function for optimization
carlosgmartin opened this issue · comments
Carlos Martin commented
I suggest adding a convenience function, `evosax.optimize`, that handles the most common usage pattern for optimization with evosax. A possible implementation (open to modification) and a usage example are shown below:
import evosax
import jax
def optimize(f, x, rng, strat, steps_per_output=1):
    """Run a strategy on ``f`` and lazily yield its progress.

    Each yielded item corresponds to ``steps_per_output`` ask/evaluate/tell
    steps, which are fused into a single jitted ``jax.lax.scan`` call.

    Args:
        f: Objective ``f(params, rng) -> scalar``. It is vmapped over the
            population, and each member receives its own PRNG key.
        x: Initial parameters, passed to ``strat.initialize`` as ``init_mean``.
        rng: JAX PRNG key driving initialization, sampling and evaluation.
        strat: Strategy object exposing ``popsize``, ``initialize``, ``ask``,
            ``tell`` and ``param_reshaper.reshape_single`` (e.g. an evosax
            strategy).
        steps_per_output: Number of strategy steps run between yields. Must
            be a positive integer.

    Returns:
        An infinite generator of ``(params, mean_fitness)`` pairs. ``params``
        is ``state.mean`` reshaped via ``strat.param_reshaper``, and
        ``mean_fitness`` is the fitness averaged over every population member
        evaluated during the most recent ``steps_per_output`` steps.

    Raises:
        ValueError: If ``steps_per_output`` is less than 1. The error is
            raised when ``optimize`` is called, not on the first ``next()``.
    """
    if steps_per_output < 1:
        # A zero-length scan would make the epoch mean NaN (mean of an empty
        # array), so the generator would silently yield NaN forever.
        raise ValueError(
            f"steps_per_output must be a positive integer, got {steps_per_output!r}"
        )
    return _optimize_iter(f, x, rng, strat, steps_per_output)


def _optimize_iter(f, x, rng, strat, steps_per_output):
    """Generator behind ``optimize``; assumes arguments are already validated."""

    def es_step(state, step_rng):
        # One key samples the population; one key per member evaluates f.
        rngs = jax.random.split(step_rng, 1 + strat.popsize)
        population, state = strat.ask(rngs[0], state)
        fitness = jax.vmap(f)(population, rngs[1:])
        state = strat.tell(population, fitness, state)
        return state, fitness.mean()

    @jax.jit
    def es_epoch(state, epoch_rng):
        step_rngs = jax.random.split(epoch_rng, steps_per_output)
        state, step_means = jax.lax.scan(es_step, state, step_rngs)
        # Every step evaluates the same number of members, so the mean of the
        # per-step means equals the mean over all evaluations in the epoch.
        return state, step_means.mean()

    @jax.jit
    def output_fn(state):
        return strat.param_reshaper.reshape_single(state.mean)

    rng, init_rng = jax.random.split(rng)
    state = strat.initialize(init_rng, init_mean=x)
    while True:
        rng, epoch_rng = jax.random.split(rng)
        state, mean_fitness = es_epoch(state, epoch_rng)
        yield output_fn(state), mean_fitness
def f(x, rng):
    """Toy objective: squared norm of ``x`` plus standard-normal noise scaled by 0.1."""
    noise = jax.random.normal(rng) * 0.1
    squared_norm = x @ x
    return squared_norm + noise
def main():
    """Minimize the noisy toy objective ``f`` with OpenES, printing progress.

    ``optimize`` yields indefinitely, so this loop runs until interrupted.
    """
    init_params = jax.numpy.ones(2)
    strategy = evosax.OpenES(
        popsize=10,
        pholder_params=init_params,
        opt_name="adam",
        lrate_init=1e-4,
        sigma_init=1e-1,
    )
    progress = optimize(f, init_params, jax.random.PRNGKey(0), strategy)
    for _params, mean_fitness in progress:
        print(mean_fitness)
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
What do you think? I can submit a PR if desired.