araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)


[Enhancement] Support for large gradient_steps in SAC

LabChameleon opened this issue

Description:
The Jax implementation of SAC is very slow to compile when using larger values of gradient_steps, e.g. 1000. Consider

sbx/sbx/sac/sac.py

Lines 333 to 352 in b8dbac1

@classmethod
@partial(jax.jit, static_argnames=["cls", "gradient_steps"])
def _train(
    cls,
    gamma: float,
    tau: float,
    target_entropy: np.ndarray,
    gradient_steps: int,
    data: ReplayBufferSamplesNp,
    policy_delay_indices: flax.core.FrozenDict,
    qf_state: RLTrainState,
    actor_state: TrainState,
    ent_coef_state: TrainState,
    key,
):
    actor_loss_value = jnp.array(0)

    for i in range(gradient_steps):

        def slice(x, step=i):

I think the problem lies in unrolling the loop over too many gradient steps. Removing line 334, so that the function is not jitted, avoids the problem.
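
To illustrate, here is a minimal toy sketch (not sbx code; the update is just a stand-in for one gradient step) showing why compilation time grows with the number of steps: jax.jit traces the Python for loop body once per iteration, so the compiled program contains gradient_steps copies of it.

import time
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["n_steps"])
def unrolled(x, n_steps):
    # The Python loop is unrolled at trace time, so the compiled
    # program contains n_steps copies of the body.
    for _ in range(n_steps):
        x = x + jnp.sin(x)  # toy stand-in for one gradient step
    return x

start = time.perf_counter()
unrolled(jnp.ones(()), n_steps=1000).block_until_ready()
print(f"first call (incl. compilation): {time.perf_counter() - start:.1f}s")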

To Reproduce

from sbx import SAC
import gymnasium as gym

env = gym.make('Pendulum-v1')
model = SAC('MlpPolicy', env, verbose=1, gradient_steps=1000)

model.learn(100000)

Expected behavior

Compilation should be fast, even for large values of gradient_steps.

Potential Fix

I adjusted the implementation by moving all computations in the loop body of SAC._train into a new jitted function gradient_step. Calling this function from a JAX fori_loop solves the issue and compiles almost instantly, as sketched below. If you agree with this approach, I would propose a PR with my solution.
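
As a rough sketch of the idea (the body is a placeholder, not the actual SAC update; gradient_step here is only illustrative): jax.lax.fori_loop traces the loop body once, so compilation time no longer grows with the number of steps.

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["n_steps"])
def rolled(x, n_steps):
    def gradient_step(i, carry):
        # Placeholder for one SAC gradient step; in the real code the
        # step index i would select the i-th pre-sampled mini-batch
        # (e.g. via jax.lax.dynamic_slice on the replay data).
        return carry + jnp.sin(carry)

    # The body is traced exactly once, regardless of n_steps, so
    # compilation is almost instant even for n_steps=1000.
    return jax.lax.fori_loop(0, n_steps, gradient_step, x)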

System Info

  • Describe how the library was installed (pip, docker, source, ...): pip
  • sbx-rl version: 0.7.0
  • Python version: 3.11
  • Jax version: 0.4.14
  • Gymnasium version: 0.29

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Hello,
this is actually a known issue...
I tried in the past to replace it with something similar to what DQN uses (https://github.com/araffin/sbx/blob/master/sbx/dqn/dqn.py#L162), but I didn't manage to get everything working as before (including the speed of the training loop once compiled).
However, if you manage to have both fast compilation time and fast runtime, I would be happy to receive a PR for it =)
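
For checking both criteria separately, a generic timing pattern like this can help (plain JAX, not sbx code): the first call to a jitted function includes tracing and compilation, while a second call with the same argument shapes measures steady-state runtime only.

import time

import jax

def time_compile_and_run(fn, *args):
    # First call: tracing + XLA compilation + execution.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - t0
    # Second call: execution only, the compilation is cached.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    second = time.perf_counter() - t0
    return first, second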

Hi,
thanks for your reply! I was not aware that you already knew about this issue. I will have another in-depth look and see whether my implementation actually offers any improvement over your existing approach. If it does, I would be happy to make a PR :)