araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

[Feature Request] Passing custom activation function in policy_kwargs

paolodelia99 opened this issue

🚀 Feature

Allow passing a Flax activation function (e.g. one from the flax.linen.activation module) through the policy_kwargs argument when creating an SBX model.

Motivation

In the current implementation of SBX, users cannot pass a custom activation function when creating a model. This limits flexibility, for example when experimenting with a different nonlinearity for the actor and critic networks.

Pitch

Example:

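import flax.linen as nn
import gymnasium as gym
from sbx import TD3

env = gym.make("Pendulum-v1")  # example continuous-action env; any such env works
my_custom_activation_fn = nn.elu  # placeholder: any Flax activation function

policy_kwargs = dict(activation_fn=my_custom_activation_fn, net_arch=dict(pi=[64, 64], qf=[64, 64]))

model = TD3("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)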

Idea on how to implement it

Add an activation_fn attribute to the underlying classes that make up the policy (e.g. Critic and Actor in td3/policy.py).

Hello,
that sounds reasonable. Would you like to contribute such a feature?

Sure.