Implement as many RL algorithms as possible in JAX.

Algorithm | Action Space | Method | File |
---|---|---|---|
DDPG | Continuous | Model-Free | ddpg.py |
TD3 | Continuous | Model-Free | td3.py |
SAC - learned temperature | Continuous | Model-Free | sac.py |
DrQ | Continuous | Model-Free | drq.py |
DroQ | Continuous | Model-Free | sac.py |
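
The "learned temperature" SAC row differs from the "fixed temperature" variant in one respect: whether the entropy coefficient alpha is optimized during training. The sketch below shows a learned-temperature update. It assumes the common formulation that drives the policy entropy toward a target value, often set to `-action_dim`. The names `temperature_loss` and `update_temperature` are hypothetical, and the plain gradient step is an illustrative choice that is not necessarily what `sac.py` does; an optimizer such as Optax would normally replace it. With a fixed temperature, alpha is simply held constant and this update is skipped.

```python
import jax
import jax.numpy as jnp


def temperature_loss(log_alpha, log_probs, target_entropy):
    """Loss for a learned SAC temperature alpha = exp(log_alpha).

    log_probs: log pi(a|s) for actions sampled from the current policy.
    They are treated as constants, so only log_alpha receives gradients.
    """
    alpha = jnp.exp(log_alpha)
    # If the policy entropy (-log_probs) is below target_entropy, a gradient
    # step increases alpha (more exploration); otherwise it decreases alpha.
    return jnp.mean(-alpha * jax.lax.stop_gradient(log_probs + target_entropy))


@jax.jit
def update_temperature(log_alpha, log_probs, target_entropy, lr=3e-4):
    # Plain gradient descent on log_alpha (hypothetical; Optax is typical).
    grad = jax.grad(temperature_loss)(log_alpha, log_probs, target_entropy)
    return log_alpha - lr * grad


# Example: 6-dimensional action space, target entropy = -6.
log_alpha = jnp.array(0.0)
log_alpha = update_temperature(log_alpha, jnp.full((256,), 8.0), -6.0)
```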

Algorithm | Action Space | Method | File |
---|---|---|---|
DQN | Discrete | Model-Free | |
Rainbow | Discrete | Model-Free | |
PlaNet | Continuous/Discrete | Model-Based | |
Dreamer | Continuous/Discrete | Model-Based | |
DreamerV2 | Continuous/Discrete | Model-Based | |
TRPO | Continuous/Discrete | Model-Free | |
PPO | Continuous/Discrete | Model-Free | |
DrQv2 | Continuous | Model-Free | |
SAC - fixed temperature | Continuous | Model-Free | |
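
DQN, the first entry in the table above, regresses Q(s, a) toward a one-step bootstrapped target computed with a separate target network. The sketch below shows that loss. It assumes Q-values are already computed as `[batch, num_actions]` arrays and applies a Huber loss to the TD error. `dqn_loss`, `huber` and their arguments are hypothetical names and are not part of this repository.

```python
import jax
import jax.numpy as jnp


def huber(x, delta=1.0):
    # Quadratic for |x| <= delta, linear beyond (DQN-style error clipping).
    abs_x = jnp.abs(x)
    quad = jnp.minimum(abs_x, delta)
    return 0.5 * quad**2 + delta * (abs_x - quad)


def dqn_loss(q_online, q_target_next, actions, rewards, dones, gamma=0.99):
    """One-step DQN TD loss on a batch.

    q_online:       [B, A] Q(s, .) from the online network
    q_target_next:  [B, A] Q(s', .) from the target network
    actions:        [B] integer actions taken in s
    rewards, dones: [B] floats (dones = 1.0 at terminal transitions)
    """
    # Target r + gamma * max_a' Q_target(s', a'), no bootstrap at terminals.
    targets = rewards + gamma * (1.0 - dones) * jnp.max(q_target_next, axis=-1)
    targets = jax.lax.stop_gradient(targets)
    # Q-values of the actions that were actually taken.
    q_taken = jnp.take_along_axis(q_online, actions[:, None], axis=-1).squeeze(-1)
    return jnp.mean(huber(targets - q_taken))


# Gradients flow only into the online Q-values (or the params producing them).
grad_fn = jax.grad(dqn_loss)
```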

- JAX tutorials (https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html, https://jax.readthedocs.io/en/latest/jax-101/index.html)
- JaxRL: repository with JAX implementations of RL algorithms (https://github.com/ikostrikov/jaxrl/tree/main/jaxrl)