RobertTLange / evosax

Evolution Strategies in JAX 🦎

`BatchStrategy` for simultaneous subpopulation `ask`/`tell`/`initialize`

RobertTLange opened this issue.

I would like to add a new (abstract) strategy wrapper class, `BatchStrategy`, which either takes a single strategy as input or instantiates one, and then performs batched versions of `ask`, `tell` and `initialize`. This makes it possible to run multiple subpopulations (all with the same population size) simultaneously, vectorized on one device or in parallel across devices.
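For the device-parallel case, one option is to split the subpopulation axis across devices and nest `jax.vmap` inside `jax.pmap`. The helper below is a hypothetical sketch (`make_subpop_map` is not part of evosax). It assumes every argument carries a leading subpopulation axis, and it falls back to plain `jax.vmap` on a single device or when the subpopulations do not split evenly across devices.

```python
import jax


def make_subpop_map(fn, num_subpops):
    """Map `fn` over a leading subpopulation axis (hypothetical helper).

    Assumes every argument of `fn` is batched along axis 0 (no shared args).
    """
    n_devices = jax.local_device_count()
    if n_devices == 1 or num_subpops % n_devices != 0:
        return jax.vmap(fn)

    # pmap over devices, vmap over the subpopulations living on each device
    inner = jax.pmap(jax.vmap(fn))
    per_device = num_subpops // n_devices

    def mapped(*args):
        # (subpops, ...) -> (devices, subpops_per_device, ...)
        split = jax.tree_util.tree_map(
            lambda a: a.reshape((n_devices, per_device) + a.shape[1:]), args)
        out = inner(*split)
        # Merge the device and per-device axes back into one subpopulation axis
        return jax.tree_util.tree_map(
            lambda a: a.reshape((num_subpops,) + a.shape[2:]), out)

    return mapped


def sample(key):
    # One subpopulation: 4 members in 2 dimensions
    return jax.random.normal(key, (4, 2))


keys = jax.random.split(jax.random.PRNGKey(0), 8)
x = make_subpop_map(sample, 8)(keys)
print(x.shape)  # (8, 4, 2)
```

Because each subpopulation gets its own split key, the subpopulations draw different samples, and the output shape is the same on either code path.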

The high-level brainstorming mindmap looks as follows:

[Image: high-level brainstorming mindmap]

The workload can be roughly divided into three blocks:

  • Add a base `BatchStrategy` wrapper with the core functionality, assuming independent subpopulations (no communication between them; see the snippet below).
  • Add a CommunicationProtocol extension (name still open) that is applied before/after the batched tell, e.g. sharing the top-k members or other information between subpopulations.
  • Add a MetaController extension that adapts hyperparameters similar to how PBT works (e.g. exploit better-performing subpopulations by copying their state, explore new hyperparameters).
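The PBT-style "exploit" step from the last bullet could act directly on the batched search state, since every leaf of that pytree has a leading subpopulation axis. Below is a minimal sketch under that assumption; `exploit_truncation` and the per-subpopulation score (lower is better, e.g. best fitness so far) are hypothetical, and the "explore" step (perturbing hyperparameters) is left out.

```python
import jax
import jax.numpy as jnp


def exploit_truncation(batch_state, subpop_scores, num_replace):
    """Copy the states of the best subpopulations into the worst ones.

    batch_state:   pytree whose leaves have a leading subpopulation axis
    subpop_scores: (subpops,) array, lower is better
    num_replace:   Python int, 1 <= num_replace <= subpops // 2
    """
    assert 1 <= num_replace <= subpop_scores.shape[0] // 2
    order = jnp.argsort(subpop_scores)  # best ... worst
    best, worst = order[:num_replace], order[-num_replace:]
    # source[i] = subpopulation whose state subpopulation i will hold
    source = jnp.arange(subpop_scores.shape[0]).at[worst].set(best)
    return jax.tree_util.tree_map(lambda leaf: leaf[source], batch_state)


state = {"mean": jnp.array([[0.0], [1.0], [2.0], [3.0]]),
         "sigma": jnp.array([0.1, 0.2, 0.3, 0.4])}
scores = jnp.array([5.0, 1.0, 3.0, 9.0])  # subpop 1 is best, subpop 3 is worst
new_state = exploit_truncation(state, scores, 1)
# new_state["mean"][:, 0] -> [0., 1., 2., 1.]: subpop 3 now continues from subpop 1
```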

Let's quickly sketch a rough design idea for the first part:

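```python
import jax


class BatchStrategy(object):
    def __init__(self, strategy, num_dims, popsize, subpopulations):
        # Assumes `strategy` is an evosax-style strategy class constructed as
        # strategy(popsize=..., num_dims=...) that exposes initialize/ask/tell
        assert popsize % subpopulations == 0, "popsize must be divisible by subpopulations"
        self.num_dims = num_dims
        self.popsize = popsize
        self.subpopulations = subpopulations
        self.popsize_per_subpop = popsize // subpopulations
        self.strategy = strategy(popsize=self.popsize_per_subpop, num_dims=num_dims)
        # Setup map fct based on availability of devices etc. -> see problem rollouts
        # (plain jax.vmap on a single device for now)

    def initialize(self, rng, params):
        # One key and one independent search state per subpopulation;
        # the strategy params are shared across subpopulations (in_axes=None)
        batch_rng = jax.random.split(rng, self.subpopulations)
        state = jax.vmap(self.strategy.initialize, in_axes=(0, None))(batch_rng, params)
        return state

    def ask(self, rng, state, params):
        batch_rng = jax.random.split(rng, self.subpopulations)
        batch_x, state = jax.vmap(self.strategy.ask, in_axes=(0, 0, None))(batch_rng, state, params)
        # Flatten subpopulation proposals back into one flat population:
        # batch_x (subpops, popsize_per_subpop, num_dims) -> x (popsize, num_dims)
        x = batch_x.reshape(self.popsize, self.num_dims)
        return x, state

    def tell(self, x, fitness, state, params):
        # Reshape the flat search/fitness arrays into subpopulation blocks, then tell
        batch_x = x.reshape(self.subpopulations, self.popsize_per_subpop, self.num_dims)
        batch_fitness = fitness.reshape(self.subpopulations, self.popsize_per_subpop)
        state = jax.vmap(self.strategy.tell, in_axes=(0, 0, 0, None))(batch_x, batch_fitness, state, params)
        return state
```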
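The flatten in `ask` and the reshape in `tell` only line up if they are exact inverses: with row-major reshapes, member `i` of subpopulation `s` sits at row `s * popsize_per_subpop + i` of the flat array, so the evaluation code must not reorder rows. The snippet below checks that round trip; `toy_ask` is a hypothetical stand-in for a real strategy's `ask` (isotropic Gaussian samples around one mean per subpopulation).

```python
import jax
import jax.numpy as jnp

subpops, popsize_per_subpop, num_dims = 3, 4, 2
popsize = subpops * popsize_per_subpop


def toy_ask(key, mean):
    # Stand-in for strategy.ask: Gaussian samples around this subpopulation's mean
    return mean + jax.random.normal(key, (popsize_per_subpop, num_dims))


keys = jax.random.split(jax.random.PRNGKey(0), subpops)
means = jnp.arange(subpops, dtype=jnp.float32)[:, None] * jnp.ones((subpops, num_dims))
batch_x = jax.vmap(toy_ask)(keys, means)  # (subpops, popsize_per_subpop, num_dims)
x = batch_x.reshape(popsize, num_dims)    # flat population, as returned by ask

fitness = jnp.sum(x ** 2, axis=-1)        # evaluated on the flat population
batch_fitness = fitness.reshape(subpops, popsize_per_subpop)  # as done in tell
```

Each row of `batch_fitness` then holds exactly the fitness values of the members that subpopulation proposed.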

CC @DiamonDiva

@DiamonDiva - The basic building blocks are now implemented in the subpops subdirectory. But there are still a couple of open tasks:

  • Right now there are no communication strategies between the subpopulations (only the independent case). The protocol could have two parts: broadcast_fitness and broadcast_state.
  • Would be awesome to have some small validation experiments as in your 10-d Rosenbrock setup.
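As an illustration of the "share top-k members" idea (here with k = 1), a communication step applied before the batched tell could overwrite each subpopulation's worst member with the overall best member. This is a minimal sketch assuming minimization; `broadcast_best` is a hypothetical name, not the planned broadcast_fitness/broadcast_state API.

```python
import jax.numpy as jnp


def broadcast_best(batch_x, batch_fitness):
    """Replace each subpopulation's worst member with the global best (minimization).

    batch_x:       (subpops, popsize_per_subpop, num_dims)
    batch_fitness: (subpops, popsize_per_subpop)
    """
    subpops = batch_fitness.shape[0]
    flat_idx = jnp.argmin(batch_fitness)  # index into the flattened fitness array
    best_x = batch_x.reshape(-1, batch_x.shape[-1])[flat_idx]
    best_f = batch_fitness.reshape(-1)[flat_idx]
    worst = jnp.argmax(batch_fitness, axis=1)  # worst member index per subpopulation
    rows = jnp.arange(subpops)
    new_x = batch_x.at[rows, worst].set(best_x)
    new_fitness = batch_fitness.at[rows, worst].set(best_f)
    return new_x, new_fitness


batch_x = jnp.array([[[0.0, 0.0], [1.0, 1.0]],
                     [[2.0, 2.0], [3.0, 3.0]]])
batch_fitness = jnp.array([[0.5, 2.0], [8.0, 18.0]])
new_x, new_f = broadcast_best(batch_x, batch_fitness)
# new_f -> [[0.5, 0.5], [8.0, 0.5]]
```

The result would then be passed to the vmapped `tell`, so every sub-strategy updates with the global elite in its batch.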