araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

[Question] Speedup compared to SB3

thomashirtz opened this issue

Question

One of the main features of JAX compared to torch or TF is speed. Would it be possible to showcase the speedup obtained with SBX compared to SB3, both on environments that are fully jitted and on ones that are not? It would give insight into how large the speedup is and whether it is worth switching from one library to the other.

Thanks !

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)

Would it be possible to showcase the speedup obtained using SBX compared to SB3 on environments that are fully jitted/not fully jitted?

See https://wandb.ai/openrlbenchmark/sb3/reports/SB3-vs-SBX--VmlldzoyOTczNDg2 for a partial report.
You can already do the comparison yourself using the openrlbenchmark runs (both SB3 and SBX runs are logged there).

For runtime reports, you can have a look at https://arxiv.org/abs/2310.05808
Please note that you should use the CPU device with PPO when not using a CNN.
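To run such a runtime comparison yourself, a minimal timing sketch could look like the following. It assumes you wrap each library's training call in a function that takes the number of timesteps; the helper names `measure_throughput` and `speedup` are hypothetical and not part of SB3 or SBX. The commented usage passes `device="cpu"` to the SB3 PPO model, following the advice above.

```python
import time
from typing import Callable


def measure_throughput(
    train_fn: Callable[[int], object], total_timesteps: int, repeats: int = 3
) -> float:
    """Run `train_fn(total_timesteps)` several times and return the best
    throughput in environment steps per second.

    Keeping the best run may reduce the impact of one-off costs such as
    JAX compiling jitted functions on their first call.
    """
    best = 0.0
    for _ in range(repeats):
        start = time.perf_counter()
        train_fn(total_timesteps)
        elapsed = max(time.perf_counter() - start, 1e-9)  # avoid division by zero
        best = max(best, total_timesteps / elapsed)
    return best


def speedup(baseline_sps: float, candidate_sps: float) -> float:
    """Ratio of throughputs: values > 1 mean the candidate is faster."""
    return candidate_sps / baseline_sps


# Hypothetical usage (requires stable-baselines3 and sbx to be installed):
#   from stable_baselines3 import PPO as SB3PPO
#   from sbx import PPO as SBXPPO
#   sb3_sps = measure_throughput(
#       lambda n: SB3PPO("MlpPolicy", "CartPole-v1", device="cpu").learn(total_timesteps=n),
#       10_000,
#   )
#   sbx_sps = measure_throughput(
#       lambda n: SBXPPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=n),
#       10_000,
#   )
#   print(f"SBX speedup over SB3: {speedup(sb3_sps, sbx_sps):.2f}x")
```

Timing the whole call, including model creation, keeps the sketch simple; a finer-grained benchmark would time `learn()` separately from setup.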

SBX still uses the SB3 NumPy replay buffer, so it mainly helps by speeding up the gradient updates (which are currently the bottleneck). Jitting the environment will give an additional boost if the environment is the bottleneck.
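Why the gain depends on where the time goes can be made concrete with Amdahl's law: when only part of the runtime gets faster, the unchanged part caps the overall speedup. The sketch below assumes you have measured (for example with a profiler) what fraction of training time each phase takes; the phase names and numbers are illustrative assumptions, not measurements of SB3 or SBX.

```python
from typing import Dict, Tuple


def overall_speedup(phases: Dict[str, Tuple[float, float]]) -> float:
    """Amdahl's law over several phases of a training loop.

    `phases` maps a phase name to (fraction_of_total_runtime, speedup_of_that_phase).
    Fractions must sum to 1; a speedup of 1.0 means the phase is unchanged.
    """
    total_fraction = sum(fraction for fraction, _ in phases.values())
    if abs(total_fraction - 1.0) > 1e-6:
        raise ValueError(f"runtime fractions must sum to 1, got {total_fraction}")
    if any(factor <= 0 for _, factor in phases.values()):
        raise ValueError("speedup factors must be positive")
    # New runtime, relative to the old one, is the sum of each shrunk phase.
    new_relative_time = sum(fraction / factor for fraction, factor in phases.values())
    return 1.0 / new_relative_time


# Illustrative numbers only: gradient updates become 5x faster,
# while env stepping and the NumPy replay buffer stay as they were.
only_updates = overall_speedup(
    {"gradient_update": (0.7, 5.0), "env_step": (0.2, 1.0), "replay_buffer": (0.1, 1.0)}
)
# Same split, but the env is also jitted and becomes 4x faster.
updates_and_env = overall_speedup(
    {"gradient_update": (0.7, 5.0), "env_step": (0.2, 4.0), "replay_buffer": (0.1, 1.0)}
)
print(f"{only_updates:.2f}x vs {updates_and_env:.2f}x")  # prints: 2.27x vs 3.45x
```

In this made-up split, even infinitely fast updates could not exceed about 3.3x, because 30% of the time is spent elsewhere; that is why jitting the environment helps once it becomes the bottleneck.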

Related to #27

EDIT: the other main difference is that SBX only supports a subset of SB3's features (for instance, CNNs are currently not implemented, but PRs are welcome ;))