RobertTLange / evosax

Evolution Strategies in JAX 🦎


Evaluating a population of batched environments with CMA-ES

nhansendev opened this issue

I'm trying to implement a simple Evaluator class that handles rollouts of batched environments for a population of MLP networks whose parameters are proposed by CMA-ES.

Each environment contains a batch of episodes that can be stepped through in parallel by one network. Each network is paired with one batched environment, creating a population of batched environments to iterate through.

I've tried implementing this using the brax control example as a reference:

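```python
# Map the network over axis 0 of the state it receives; share the params.
self.rollout_repeats = jax.vmap(self.network, in_axes=(0, None))
# Map over the population via map_dict; the state is passed unmapped (None).
self.rollout_pop = jax.vmap(self.rollout_repeats, in_axes=(None, map_dict))
```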

where map_dict is taken from param_reshaper.vmap_dict.
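For reference, jax.vmap accepts a pytree as an in_axes entry, so one entry can say "map every parameter leaf over axis 0". A minimal sketch, assuming vmap_dict is a pytree of axis indices with the same structure as the parameters (the toy one-layer policy and its parameter names are hypothetical):

```python
import jax
import jax.numpy as jnp

# Hypothetical single-layer policy parameters for a population of 4,
# stacked along a leading population axis.
pop_params = {
    "w": jnp.ones((4, 3, 2)),  # (population_qty, features, action_dim)
    "b": jnp.zeros((4, 2)),    # (population_qty, action_dim)
}

# An in_axes pytree mirroring the params: every leaf is mapped over axis 0.
# Assumption: this is the kind of structure param_reshaper.vmap_dict holds.
params_axes = jax.tree_util.tree_map(lambda _: 0, pop_params)

def policy(obs, params):
    # obs: (features,) -> action: (action_dim,)
    return obs @ params["w"] + params["b"]

obs = jnp.ones(3)  # one observation shared by every member, hence in_axes None
actions = jax.vmap(policy, in_axes=(None, params_axes))(obs, pop_params)
print(actions.shape)  # (4, 2)
```

Passing None for the observation is only correct when the observation has no population axis; every member then sees the same input.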

The function is called once per step as:

```python
action = self.rollout_pop(jnp.stack(state), policy_params, rng=rng_net)
```
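One detail worth checking in this call: jax.vmap always maps keyword arguments over their leading axis, so a single key passed as rng=rng_net is not broadcast to the population; its axis-0 size is checked against the other mapped arguments. A minimal sketch, assuming a hypothetical noisy linear policy, that passes the key positionally and splits one key per population member:

```python
import jax
import jax.numpy as jnp

population_qty, batch, features, action_dim = 4, 5, 3, 2

def noisy_policy(obs, params, rng):
    # Hypothetical stochastic policy: linear layer plus Gaussian noise.
    # obs: (batch, features) -> (batch, action_dim)
    mean = obs @ params["w"] + params["b"]
    return mean + 0.1 * jax.random.normal(rng, mean.shape)

pop_params = {
    "w": jnp.ones((population_qty, features, action_dim)),
    "b": jnp.zeros((population_qty, action_dim)),
}
state = jnp.ones((population_qty, batch, features))

# Give every population member its own key and map it explicitly over axis 0.
rng_net = jax.random.PRNGKey(0)
member_keys = jax.random.split(rng_net, population_qty)  # (population_qty, 2)

rollout_pop = jax.vmap(noisy_policy, in_axes=(0, 0, 0))
actions = rollout_pop(state, pop_params, member_keys)
print(actions.shape)  # (4, 5, 2)
```

If every member should share the same key instead, pass it positionally with in_axes None for that argument.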

The state array passed to self.rollout_pop has shape (population_qty, environment_batch, features). The intent is that each network processes only its own (environment_batch, features) slice rather than the whole array, producing an output of shape (population_qty, environment_batch, action_dim).
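A minimal sketch of that nested vmap, assuming a hypothetical one-layer network with signature network(obs, params) and a tree of 0s standing in for param_reshaper.vmap_dict. Compared with the snippet in the question, the outer vmap maps the state over its population axis (0) instead of passing it unmapped (None):

```python
import jax
import jax.numpy as jnp

population_qty, environment_batch, features, action_dim = 4, 5, 3, 2

def network(obs, params):
    # Hypothetical single-observation policy: (features,) -> (action_dim,)
    return jnp.tanh(obs @ params["w"] + params["b"])

pop_params = {
    "w": jnp.ones((population_qty, features, action_dim)),
    "b": jnp.zeros((population_qty, action_dim)),
}
# Stand-in for param_reshaper.vmap_dict: map every parameter leaf over axis 0.
params_axes = jax.tree_util.tree_map(lambda _: 0, pop_params)

# Inner vmap: one network over its environment batch (obs mapped, params shared).
rollout_repeats = jax.vmap(network, in_axes=(0, None))
# Outer vmap: map BOTH the state and the params over the population axis,
# so member i only sees state[i] of shape (environment_batch, features).
rollout_pop = jax.vmap(rollout_repeats, in_axes=(0, params_axes))

state = jnp.ones((population_qty, environment_batch, features))
action = rollout_pop(state, pop_params)
print(action.shape)  # (4, 5, 2)

# For contrast: in_axes=(None, params_axes) hands the whole state to every
# member, so the inner vmap iterates over the population axis instead.
broadcast = jax.vmap(rollout_repeats, in_axes=(None, params_axes))(state, pop_params)
print(broadcast.shape)  # (4, 4, 5, 2)
```

With the unmapped state, each network receives a (environment_batch, features) matrix where it expects a single observation; a plain matmul tolerates that and yields the extra axis above, while other layers may fail with shape errors.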

So far I have only gotten various errors and assertion failures, even after simplifying to a non-batched environment.
Please let me know what's wrong with the vmaps. Are they even appropriate for this task? The proper usage of the map_dict is definitely a source of confusion here.

Never mind, I was able to get a working example of the batching behavior using an nn.Dense layer. Initializing the MLP network is giving me trouble, but that is clearly a generic Flax question, not one related to evosax.
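On the initialization side, a pure-JAX sketch of a common pattern (this is neither Flax nor evosax's ParameterReshaper): initialize one template MLP to fix the parameter structure, flatten it with jax.flatten_util.ravel_pytree, and unravel a population of flat candidate vectors, such as those an ES like CMA-ES proposes, into per-member parameter pytrees. The MLP, its layer sizes, and the random candidates are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def init_mlp(rng, sizes=(3, 16, 2)):
    # Hypothetical pure-JAX MLP initializer returning a list of layer dicts.
    keys = jax.random.split(rng, len(sizes) - 1)
    return [
        {"w": jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in),
         "b": jnp.zeros(n_out)}
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp(obs, params):
    # obs: (..., features) -> (..., action_dim); ReLU on hidden layers only.
    x = obs
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    return x @ params[-1]["w"] + params[-1]["b"]

# One template initialization fixes the parameter structure and size.
template = init_mlp(jax.random.PRNGKey(0))
flat, unravel = ravel_pytree(template)
print(flat.shape)  # (98,) = 3*16 + 16 + 16*2 + 2

# Stand-in for the flat candidates an ES would propose each generation.
population_qty = 4
candidates = jax.random.normal(jax.random.PRNGKey(1), (population_qty, flat.size))
pop_params = jax.vmap(unravel)(candidates)  # every leaf gains a leading axis of 4

state = jnp.ones((population_qty, 5, 3))  # (population_qty, environment_batch, features)
action = jax.vmap(mlp)(state, pop_params)
print(action.shape)  # (4, 5, 2)
```

Only the template needs a real initializer; the population itself comes from the strategy's flat samples.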