RobertTLange / evosax

Evolution Strategies in JAX 🦎

CMA_ES + jax.lax.scan

pharringtonp19 opened this issue · comments

I was looking to compare CMA_ES with Differential_ES in this notebook.

When I run CMA_ES using a jax.lax.scan function as my training loop, I get the following error:

ConcretizationTypeError                   Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/evosax/strategies/cma_es.py in params_strategy(self)
     44             1 + c_1 / c_mu,
     45             1 + (2 * mu_eff_minus) / (mu_eff + 2),
---> 46             (1 - c_1 - c_mu) / (self.num_dims * c_mu),
     47         )
     48         positive_sum = jnp.sum(weights_prime[weights_prime > 0])

I don't get this error when I run CMA_ES in a standard for loop.

Hi @pharringtonp19,
Thank you for reporting this and giving evosax a try! As far as I can tell, the error results from the interaction of jit with the shape-dependent definition of the default parameters. CMA-ES uses fairly intricate heuristics to set these defaults, which depend on the dimensionality of your problem and the selected population size.
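To illustrate the general failure mode, here is a minimal sketch that does not use evosax at all. The `default_scale` function and its formula are hypothetical and only stand in for a default-parameter heuristic; the assumption is that the heuristic relies on a Python-level operation (here the built-in `min`) that needs a concrete value. That is one common way to hit a `ConcretizationTypeError` under jit, not necessarily the exact cause inside `cma_es.py`.

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError


def default_scale(num_dims):
    # Hypothetical heuristic (not evosax code): Python's built-in min() has to
    # compare concrete values, so it cannot operate on a traced array.
    return min(1.0, 2.0 / (num_dims + 1.0))


# Outside jit the argument is a concrete value, so the heuristic simply runs.
scale_eager = default_scale(jnp.asarray(3.0))

# Under jit the argument is an abstract tracer and the comparison fails.
try:
    jax.jit(default_scale)(jnp.asarray(3.0))
    failed_under_jit = False
except ConcretizationTypeError:
    failed_under_jit = True

print(float(scale_eager), failed_under_jit)
```

The same function works eagerly and fails once traced, which matches the observation that a plain for loop runs while the scanned and jitted version does not.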

You can solve your problem by instantiating the strategy and its hyperparameters outside of your jitted function call, for example:

from functools import partial

import jax
import jax.numpy as jnp
from evosax import CMA_ES

# Instantiate the strategy and its default hyperparameters outside of jit.
strategy = CMA_ES(popsize=20, num_dims=1, elite_ratio=0.5)
es_params = strategy.default_params

# n_epochs sets the scan length, so it has to be a static argument.
@partial(jax.jit, static_argnums=(1,))
def train(key_num, n_epochs):
  rng, init_state_rng = jax.random.split(jax.random.PRNGKey(key_num))
  state = strategy.initialize(init_state_rng, es_params)

  def update(carry, t):
    rng, state = carry
    rng, rng_gen = jax.random.split(rng, 2)
    # Sample a population, evaluate it, and update the strategy state.
    x, state = strategy.ask(rng_gen, state, es_params)
    # f is the objective function defined in your notebook.
    fitness = jax.vmap(f)(x)
    state = strategy.tell(jnp.expand_dims(x, 1), fitness, state, es_params)
    return (rng, state), (jnp.min(fitness), jnp.mean(fitness), state['best_member'])

  (_, state), results = jax.lax.scan(update, (rng, state), None, length=n_epochs)
  return state, results

final_state, results = train(0, 1000)

Luckily, this issue only appears with CMA-ES for now. I am not sure whether there is a better way to handle it, and I am open to proposals.

Again, thank you 🤗
Rob

@RobertTLange Thanks for the help!