RobertTLange / evosax

Evolution Strategies in JAX 🦎

Feature request: Flexible Parameter Clipping in Evosax Strategy

bheijden opened this issue

Hi,

Currently, parameter clipping is applied universally via clip_min and clip_max in this code section. The same scalar bounds are applied to every parameter, and clipping happens on the flat parameter vector before it is reshaped.
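A rough sketch of the current behaviour described above, written in plain JAX rather than copied from evosax: one scalar clip_min/clip_max pair is broadcast over a flat (popsize, num_dims) candidate matrix, so every dimension shares a single range. The population shape and bound values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical bounds: one scalar pair shared by every parameter dimension.
clip_min, clip_max = -1.0, 1.0

# Hypothetical flat population: 4 candidates with 3 dimensions each.
key = jax.random.PRNGKey(0)
flat_pop = 5.0 * jax.random.normal(key, (4, 3))  # (popsize, num_dims)

# Scalar bounds broadcast over the whole matrix, so all parameters get the same range.
clipped = jnp.clip(flat_pop, clip_min, clip_max)
```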

While this works well for neural network weights, it is less suitable for heterogeneous parameters, such as those found in physics engine scenarios. Such parameters often differ widely in meaning and scale, and therefore need individually tailored clipping ranges.

To make clipping more flexible, I propose applying it after the param_reshaper operation. Clipping could then be applied to each leaf individually with jax.tree_util.tree_map, provided that clip_min and clip_max are specified as pytrees with the same structure as the params rather than as scalar values. This would give finer-grained control over the parameter optimization process.
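A minimal sketch of the proposed per-leaf clipping, assuming clip_min and clip_max are supplied as pytrees with the same structure as the params. The parameter names (mass, friction, damping) and their bounds are hypothetical, and clip_pytree is an illustrative helper, not an existing evosax function.

```python
import jax
import jax.numpy as jnp

# Hypothetical physics-style parameters with very different scales.
params = {
    "mass": jnp.array([0.5, 120.0]),
    "friction": jnp.array([1.7]),
    "damping": jnp.array([[0.01, -3.0]]),
}

# Bounds are pytrees with exactly the same structure as `params` (one bound per leaf).
clip_min = {"mass": jnp.array(0.1), "friction": jnp.array(0.0), "damping": jnp.array(0.0)}
clip_max = {"mass": jnp.array(100.0), "friction": jnp.array(1.0), "damping": jnp.array(1.0)}


def clip_pytree(tree, lo, hi):
    """Hypothetical helper: clip every leaf of `tree` to its own [lo, hi] range."""
    return jax.tree_util.tree_map(jnp.clip, tree, lo, hi)


clipped = clip_pytree(params, clip_min, clip_max)
# mass -> [0.5, 100.0], friction -> [1.0], damping -> [[0.01, 0.0]]
```

Because each bound leaf here is a scalar array, it would also broadcast over a leading population axis if the parameter leaves were batched.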

WDYT?

edit: On further inspection, this may already be possible by applying the ParameterReshaper outside the strategy to both the parameters and the clipping ranges. Would that mean I am responsible for reshaping the params etc. at the API boundary between the strategy and my code (which takes and returns pytrees)?
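One way to handle that API boundary, sketched with jax.flatten_util.ravel_pytree instead of evosax's ParameterReshaper (whose exact methods are not shown here): flatten a template pytree once, hand flat vectors to the strategy, and map candidates back with a vmapped unravel function. The params and the randomly perturbed stand-in population are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree for a physics model.
params = {"mass": jnp.array([1.0, 2.0]), "friction": jnp.array(0.3)}

# Flatten once: `flat_params` is what a strategy would operate on, `unravel` maps back.
flat_params, unravel = ravel_pytree(params)
num_dims = flat_params.shape[0]

# Stand-in for strategy output (assumption: flat candidates of shape (popsize, num_dims)).
popsize = 8
noise = jax.random.normal(jax.random.PRNGKey(0), (popsize, num_dims))
flat_pop = flat_params + 0.1 * noise

# Map each flat candidate back to a pytree before handing it to user code.
pop_tree = jax.vmap(unravel)(flat_pop)
# pop_tree["mass"] has shape (8, 2); pop_tree["friction"] has shape (8,)
```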

edit2: Or simply use strategy.param_reshaper to flatten the clip ranges into a single jnp.ndarray and set them that way. Feel free to close this issue if this is indeed the intended way of handling this problem.
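A sketch of the edit2 approach, again substituting ravel_pytree for strategy.param_reshaper: bounds given as pytrees with the same structure and leaf shapes as the params are flattened into per-dimension arrays, and jnp.clip then broadcasts them over the population axis. Whether evosax accepts array-valued clip_min/clip_max, and whether its reshaper uses the same leaf order as ravel_pytree, are assumptions to verify against the library.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical per-parameter bounds. Leaf shapes must match the params exactly,
# otherwise the flattened bounds will not line up with the flattened params.
clip_min = {"mass": jnp.array([0.1, 0.1]), "friction": jnp.array(0.0)}
clip_max = {"mass": jnp.array([10.0, 10.0]), "friction": jnp.array(1.0)}

# Same structure -> same leaf order (dict keys are sorted: friction, mass[0], mass[1]).
# Assumption: the strategy's own reshaper flattens params in this same order.
flat_min, _ = ravel_pytree(clip_min)
flat_max, _ = ravel_pytree(clip_max)

# Hypothetical flat population of shape (popsize=2, num_dims=3).
flat_pop = jnp.array([[2.0, -1.0, 50.0],
                      [-0.5, 5.0, 3.0]])

# Per-dimension bounds of shape (3,) broadcast over the population axis.
clipped = jnp.clip(flat_pop, flat_min, flat_max)
```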