`BatchStrategy` for simultaneous subpopulation `ask`/`tell`/`initialize`
RobertTLange opened this issue · comments
I would like to add a new (abstract) strategy wrapper class (`BatchStrategy`), which takes a single strategy as input or instantiates one and then performs batched versions of `ask`, `tell` and `initialize`. This provides functionality for executing multiple sub-populations (with the same population size) simultaneously in a vectorized/device-parallel fashion.
The high-level brainstorming mindmap looks as follows:
The workload can roughly be divided into 3 blocks:
- Add base `BatchStrategy` wrapper with core functionality assuming independent subpopulations (no communication between them - see snippet below).
- Add `CommunicationProtocol` extension (look for better name) that is applied before/after batched `tell`. E.g. share top-k members or other information between subpopulations.
- Add `MetaController` extension which alters hyperparameters similar to how PBT works (e.g. exploit better populations by copying their state, explore hyperparameters).
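For the second block, one possible form of "share top-k members" could look like the sketch below. The function name `share_top_k`, its signature, and the choice to overwrite each subpopulation's k worst members are all assumptions made for illustration, as is the convention that lower fitness is better.

```python
import jax
import jax.numpy as jnp


def share_top_k(batch_x, batch_fitness, k):
    """Overwrite each subpopulation's k worst members with the k best overall.

    batch_x:       (subpops, popsize_per_subpop, num_dims)
    batch_fitness: (subpops, popsize_per_subpop), lower is better (assumption).
    k must be a Python int so that the slices below have static size.
    """
    num_dims = batch_x.shape[-1]
    flat_x = batch_x.reshape(-1, num_dims)
    flat_fitness = batch_fitness.reshape(-1)

    # The k best members across all subpopulations.
    top_idx = jnp.argsort(flat_fitness)[:k]
    top_x, top_fitness = flat_x[top_idx], flat_fitness[top_idx]

    def inject(x, fitness):
        # Indices of this subpopulation's k worst members.
        worst_idx = jnp.argsort(fitness)[-k:]
        x = x.at[worst_idx].set(top_x)
        fitness = fitness.at[worst_idx].set(top_fitness)
        return x, fitness

    # Apply the same injection to every subpopulation.
    return jax.vmap(inject)(batch_x, batch_fitness)
```

Such a function would be called on the reshaped arrays right before the batched `tell`, so the base strategy itself does not need to know about the other subpopulations.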
Let's quickly sketch a rough design idea for the first part:
    import jax

    class BatchStrategy(object):
        def __init__(self, num_dims, popsize, strategy_name, subpopulations):
            self.num_dims = num_dims
            self.subpopulations = subpopulations
            self.popsize_per_subpop = popsize // subpopulations
            self.strategy = ...  # set up base strategy functionalities
            # Setup map fct based on availability of devices etc. -> see problem rollouts

        def initialize(self, rng, params):
            # One independent rng key per subpopulation
            batch_rng = jax.random.split(rng, self.subpopulations)
            state = jax.vmap(self.strategy.initialize, ...)(batch_rng, params)
            return state

        def ask(self, rng, state, params):
            batch_rng = jax.random.split(rng, self.subpopulations)
            batch_x, state = jax.vmap(self.strategy.ask, ...)(batch_rng, state, params)
            # Flatten subpopulation proposals back into flat vector
            # batch_x -> Shape: (subpops, popsize_per_subpop, num_dims)
            # x -> Shape: (popsize, num_dims)
            x = batch_x.reshape(-1, self.num_dims)
            return x, state

        def tell(self, x, fitness, state, params):
            # Reshape flat fitness/search vector into subpopulation array then tell
            # batch_fitness -> Shape: (subpops, popsize_per_subpop)
            # batch_x -> Shape: (subpops, popsize_per_subpop, num_dims)
            batch_fitness = fitness.reshape(self.subpopulations, self.popsize_per_subpop)
            batch_x = x.reshape(self.subpopulations, self.popsize_per_subpop, self.num_dims)
            state = jax.vmap(self.strategy.tell, ...)(batch_x, batch_fitness, state, params)
            return state
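To make the vectorization concrete, here is a minimal runnable sketch of the independent-subpopulation case. `ToyGaussianStrategy` is a hypothetical stand-in for a base strategy (it is not a class from the library), the wrapper takes an already-instantiated strategy rather than a `strategy_name`, and the `in_axes` choices assume that `params` are shared across subpopulations while rng keys, state, candidates and fitness carry a leading subpopulation axis.

```python
import jax
import jax.numpy as jnp


class ToyGaussianStrategy:
    """Hypothetical base strategy used only for illustration."""

    def __init__(self, num_dims, popsize):
        self.num_dims = num_dims
        self.popsize = popsize

    def initialize(self, rng, params):
        mean = params["init_scale"] * jax.random.normal(rng, (self.num_dims,))
        return {"mean": mean}

    def ask(self, rng, state, params):
        noise = jax.random.normal(rng, (self.popsize, self.num_dims))
        return state["mean"] + params["sigma"] * noise, state

    def tell(self, x, fitness, state, params):
        # Move the mean to the best (lowest-fitness) member.
        return {"mean": x[jnp.argmin(fitness)]}


class BatchStrategy:
    def __init__(self, strategy, subpopulations):
        self.subpopulations = subpopulations
        # Assumption: params are shared (None), everything else is batched (0).
        self._init = jax.vmap(strategy.initialize, in_axes=(0, None))
        self._ask = jax.vmap(strategy.ask, in_axes=(0, 0, None))
        self._tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, None))

    def initialize(self, rng, params):
        return self._init(jax.random.split(rng, self.subpopulations), params)

    def ask(self, rng, state, params):
        batch_rng = jax.random.split(rng, self.subpopulations)
        batch_x, state = self._ask(batch_rng, state, params)
        # (subpops, popsize_per_subpop, num_dims) -> (popsize, num_dims)
        return batch_x.reshape(-1, batch_x.shape[-1]), state

    def tell(self, x, fitness, state, params):
        batch_x = x.reshape(self.subpopulations, -1, x.shape[-1])
        batch_fitness = fitness.reshape(self.subpopulations, -1)
        return self._tell(batch_x, batch_fitness, state, params)


rng = jax.random.PRNGKey(0)
rng_init, rng_ask = jax.random.split(rng)
params = {"init_scale": 1.0, "sigma": 0.1}
batch = BatchStrategy(ToyGaussianStrategy(num_dims=3, popsize=4), subpopulations=5)
state = batch.initialize(rng_init, params)
x, state = batch.ask(rng_ask, state, params)   # x has shape (20, 3)
fitness = jnp.sum(x ** 2, axis=1)              # sphere function, lower is better
state = batch.tell(x, fitness, state, params)  # state["mean"] has shape (5, 3)
```

Swapping `jax.vmap` for a device-parallel map would be the place to handle the "availability of devices" comment from the sketch; that part is left out here.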
CC @DiamonDiva
@DiamonDiva - The basic building blocks are now implemented in the `subpops` subdirectory. But there are still a couple of open tasks:
- Right now there are no communication strategies between the subpopulations (only the independent case). The `Protocol` could have two parts: `broadcast_fitness` and `broadcast_state`.
- Would be awesome to have some small validation experiments as in your 10-d Rosenbrock setup.
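As a starting point for such a validation experiment, the fitness evaluation could be sketched as follows. This assumes the standard Rosenbrock definition and a flat population of shape `(popsize, num_dims)`; the helper name `best_per_subpop` is hypothetical, and the details of the original 10-d setup are not known here.

```python
import jax
import jax.numpy as jnp


def rosenbrock(x):
    """Standard Rosenbrock function for a single point x of shape (num_dims,)."""
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)


def best_per_subpop(x, subpopulations):
    """Evaluate a flat population and report the best fitness per subpopulation."""
    fitness = jax.vmap(rosenbrock)(x)  # shape (popsize,)
    best = fitness.reshape(subpopulations, -1).min(axis=1)
    return fitness, best
```

Tracking `best` over generations gives one curve per subpopulation, which makes it easy to compare the independent case against any later communication protocol.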