Feature request: Flexible Parameter Clipping in Evosax Strategy
bheijden opened this issue
Hi,
Currently, parameter clipping is applied universally with clip_min and clip_max at this code section. This approach applies the same scalar bounds to all parameters before they are reshaped.
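To illustrate the problem, here is a minimal sketch of what scalar clipping before reshaping amounts to. It is not evosax's code: jax.flatten_util.ravel_pytree stands in for the reshaper, and the parameter names ("friction", "mass") and bounds are hypothetical.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical physics parameters on very different scales.
params = {"friction": jnp.array([0.5]), "mass": jnp.array([250.0])}

# Flatten the pytree into one vector (stand-in for what the strategy operates on).
flat, unravel = ravel_pytree(params)

# One pair of scalar bounds is applied to every entry of the flat vector.
clip_min, clip_max = 0.0, 1.0
clipped = unravel(jnp.clip(flat, clip_min, clip_max))

# A range that suits "friction" squashes "mass" down to 1.0.
```

A single scalar range cannot be appropriate for both leaves at once, which is the motivation for per-leaf bounds.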
While this works well for neural network weights, it is less suitable for heterogeneous parameters, such as those found in physics-engine scenarios. Such parameters often differ substantially in meaning and scale, and therefore need their own clipping ranges.
To make this more flexible, I propose applying clipping after the param_reshaper operation. Clipping could then be applied to each leaf individually with jax.tree_util.tree_map, provided clip_min and clip_max are specified as pytrees with the same structure as the params rather than as scalar values. This would allow more fine-grained control over the parameter optimization.
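A minimal sketch of the proposed per-leaf clipping, assuming clip_min and clip_max are pytrees with the same structure as the reshaped params. The helper clip_pytree and the example bounds are hypothetical, not part of evosax.

```python
import jax
import jax.numpy as jnp


def clip_pytree(params, clip_min, clip_max):
    """Clip every leaf of `params` to the bounds in the matching leaves of clip_min/clip_max."""
    # tree_map walks the three pytrees in lockstep, so their structures must match.
    return jax.tree_util.tree_map(lambda x, lo, hi: jnp.clip(x, lo, hi), params, clip_min, clip_max)


# Reshaped population of 2 candidates: each leaf has a leading population axis.
params = {"friction": jnp.array([[0.5], [1.7]]), "mass": jnp.array([[250.0], [-3.0]])}

# Per-leaf bounds (hypothetical); shape (1,) broadcasts against (popsize, 1).
clip_min = {"friction": jnp.array([0.0]), "mass": jnp.array([1.0])}
clip_max = {"friction": jnp.array([1.0]), "mass": jnp.array([500.0])}

clipped = clip_pytree(params, clip_min, clip_max)
```

Because jnp.clip broadcasts, bounds without a population axis can be shared across all candidates.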
WDYT?
edit: On further inspection, this may already be possible by applying the ParameterReshaper outside of the strategy to both the parameters and the clipping ranges. Would this mean I am responsible for reshaping the params etc. at the API boundary between the strategy and my code (which takes and returns pytrees)?
edit2: Or simply use strategy.param_reshaper to flatten the clip ranges into a single jnp.ndarray and set clip_min and clip_max to that. Feel free to close this issue if this is indeed the intended way to handle this problem.
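A minimal sketch of the edit2 workaround. To avoid guessing method names on strategy.param_reshaper, it uses jax.flatten_util.ravel_pytree directly and assumes the reshaper flattens leaves in the same order (an assumption worth checking). It also assumes the strategy accepts array-valued clip_min/clip_max; the sketch only shows that jnp.clip broadcasts per-dimension bounds over a flat population.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical per-leaf bounds with the same structure as the params.
clip_min_tree = {"friction": jnp.array([0.0]), "mass": jnp.array([1.0])}
clip_max_tree = {"friction": jnp.array([1.0]), "mass": jnp.array([500.0])}

# Flatten the bounds with the same leaf order used for the params.
clip_min_flat, unravel = ravel_pytree(clip_min_tree)
clip_max_flat, _ = ravel_pytree(clip_max_tree)

# Flat population of shape (popsize, num_dims); columns follow the flattening order.
x = jnp.array([[0.5, 250.0], [1.7, -3.0]])

# Per-dimension bounds of shape (num_dims,) broadcast across the population.
x_clipped = jnp.clip(x, clip_min_flat, clip_max_flat)

# Map one flat candidate back to a pytree at the boundary with user code.
candidate = unravel(x_clipped[1])
```

The important invariant is that the bounds and the parameters go through the same flattening, so column i of the population is clipped by entry i of each bound vector.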