Evaluating a population of batched environments with CMA-ES
nhansendev opened this issue
I'm trying to implement a simple Evaluator class that handles rollouts of batched environments for a population of MLP networks optimized with CMA-ES.
Each environment contains a batch of episodes that can be stepped through in parallel by one network. Each network is paired with one batched environment, so the population of networks maps onto a population of batched environments.
I've tried implementing this using the brax control example as a reference:
self.rollout_repeats = jax.vmap(self.network, in_axes=(0, None))
self.rollout_pop = jax.vmap(self.rollout_repeats, in_axes=(None, map_dict))
Here map_dict is provided by param_reshaper.vmap_dict.
The function is called once per step as:
action = self.rollout_pop(jnp.stack(state), policy_params, rng=rng_net)
The state array provided to self.rollout_pop has the shape (population_qty, environment_batch, features). The intent is that each individual network does not iterate over all of the data, only over its own (environment_batch, features) slice. This would produce an output of shape (population_qty, environment_batch, action_dim).
So far I have just received various errors and assertions, even when trying to simplify it with a non-batched environment.
Please let me know what's wrong with the vmaps. Are they even appropriate for this task? The proper usage of the map_dict is definitely a source of confusion here.
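For reference, here is a minimal sketch of the nested-vmap pattern described above, written in plain JAX so it runs without flax or evosax. It rests on several assumptions: the hand-written mlp function, the layer sizes, and the parameter names are hypothetical, and map_dict is built here with jax.tree_util.tree_map as a stand-in for whatever param_reshaper.vmap_dict returns. Assuming state really carries a leading population axis, one likely problem in the snippet above is the None for state in the outer vmap, which asks JAX to broadcast the full state to every network rather than slicing it per network. Note also that jax.vmap always maps keyword arguments over their leading axis, so a key passed as rng=... is not governed by in_axes.

```python
import jax
import jax.numpy as jnp


def mlp(obs, params):
    """Hypothetical two-layer MLP acting on ONE observation of shape (features,)."""
    hidden = jnp.tanh(obs @ params["w1"] + params["b1"])
    return jnp.tanh(hidden @ params["w2"] + params["b2"])


# Illustrative sizes only; none of these come from the original post.
population_qty, environment_batch, features = 4, 3, 5
hidden_dim, action_dim = 8, 2

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Every parameter leaf has a leading population axis.
pop_params = {
    "w1": jax.random.normal(k1, (population_qty, features, hidden_dim)),
    "b1": jnp.zeros((population_qty, hidden_dim)),
    "w2": jax.random.normal(k2, (population_qty, hidden_dim, action_dim)),
    "b2": jnp.zeros((population_qty, action_dim)),
}

# Stand-in for param_reshaper.vmap_dict (assumption): a pytree with the
# same structure as the parameters, holding the axis to map for each leaf.
map_dict = jax.tree_util.tree_map(lambda _: 0, pop_params)

# Inner vmap: one network, mapped over its environment batch.
rollout_repeats = jax.vmap(mlp, in_axes=(0, None))
# Outer vmap: state AND parameters are both mapped over the population axis.
rollout_pop = jax.vmap(rollout_repeats, in_axes=(0, map_dict))

state = jax.random.normal(k3, (population_qty, environment_batch, features))
action = rollout_pop(state, pop_params)
print(action.shape)
```

With these axes, network i only ever sees state[i], so the result has shape (population_qty, environment_batch, action_dim). If the networks need a random key, passing it positionally with its own in_axes entry keeps the mapping explicit.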
Never mind, I have been able to get a working example of the batching behavior using an nn.Dense layer. Initializing the MLP network is still giving me trouble, but that is clearly a generic flax question, not related to evosax.