hamishs / JAX-RL

JAX implementations of various deep reinforcement learning algorithms.

JAX-RL

Main libraries used:

  • JAX - main framework
  • Haiku - neural networks
  • Optax - gradient-based optimisation

Algorithms implemented

  • Proximal Policy Optimization (PPO): https://arxiv.org/abs/1707.06347
  • Deep Q-Network (DQN): https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
  • Double Deep Q-Network (DDQN): https://arxiv.org/abs/1509.06461
  • Deep Recurrent Q-Network (DRQN): https://arxiv.org/abs/1507.06527
  • Deep Deterministic Policy Gradient (DDPG): https://arxiv.org/abs/1509.02971
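The core update rules behind these deep methods fit in a few lines of `jax.numpy`. The sketch below is written from the cited papers, not taken from this repository's code: the function names (`dqn_target`, `ddqn_target`, `ppo_clip_loss`, `soft_update`), the batched argument layout (`dones` as 0/1 floats) and the default hyperparameters are illustrative assumptions. DRQN uses the same target as DQN with a recurrent Q-network, so it is not shown separately.

```python
import jax
import jax.numpy as jnp


def dqn_target(rewards, dones, next_q_target, gamma=0.99):
    """DQN target: r + gamma * max_a' Q_target(s', a'), zeroed past terminal states."""
    return rewards + gamma * (1.0 - dones) * jnp.max(next_q_target, axis=-1)


def ddqn_target(rewards, dones, next_q_online, next_q_target, gamma=0.99):
    """Double DQN: the online network selects a', the target network evaluates it."""
    best_actions = jnp.argmax(next_q_online, axis=-1)
    evaluated = jnp.take_along_axis(next_q_target, best_actions[:, None], axis=-1)[:, 0]
    return rewards + gamma * (1.0 - dones) * evaluated


def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Negated PPO clipped surrogate objective, so it can be minimised."""
    ratio = jnp.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -jnp.mean(jnp.minimum(unclipped, clipped))


def soft_update(target_params, online_params, tau=0.005):
    """DDPG-style soft target update: target <- (1 - tau) * target + tau * online."""
    return jax.tree_util.tree_map(
        lambda t, o: (1.0 - tau) * t + tau * o, target_params, online_params
    )
```

The only difference between the two value targets is which network picks the next action; decoupling selection from evaluation is what reduces DQN's overestimation bias. In a training loop these functions would typically be used inside a loss passed to `jax.grad` and compiled with `jax.jit`.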

Tabular algorithms

  • Q-learning
  • Double Q-learning
  • SARSA
  • Expected SARSA
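The four tabular methods differ only in how they form the bootstrap target. A minimal sketch, assuming the Q-table is a `jax.numpy` array of shape `(n_states, n_actions)` and ignoring terminal-state handling for brevity; the function names and defaults are hypothetical and not this package's API.

```python
import jax.numpy as jnp


def q_learning_update(q, s, a, r, s_next, alpha=0.1, gamma=0.99):
    """Off-policy: bootstrap from the greedy action in s_next."""
    target = r + gamma * jnp.max(q[s_next])
    return q.at[s, a].add(alpha * (target - q[s, a]))


def sarsa_update(q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.99):
    """On-policy: bootstrap from the action actually taken in s_next."""
    target = r + gamma * q[s_next, a_next]
    return q.at[s, a].add(alpha * (target - q[s, a]))


def expected_sarsa_update(q, s, a, r, s_next, policy_next, alpha=0.1, gamma=0.99):
    """Bootstrap from the expected Q(s_next, .) under the policy's action probabilities."""
    target = r + gamma * jnp.dot(policy_next, q[s_next])
    return q.at[s, a].add(alpha * (target - q[s, a]))


def double_q_learning_update(q_a, q_b, s, a, r, s_next, alpha=0.1, gamma=0.99):
    """Update q_a: q_a selects the next action and q_b evaluates it.
    In full Double Q-learning a coin flip chooses which table to update each step."""
    a_star = jnp.argmax(q_a[s_next])
    target = r + gamma * q_b[s_next, a_star]
    return q_a.at[s, a].add(alpha * (target - q_a[s, a]))


def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities of an epsilon-greedy policy, usable as policy_next."""
    n = q_values.shape[-1]
    probs = jnp.full(n, epsilon / n)
    return probs.at[jnp.argmax(q_values)].add(1.0 - epsilon)
```

Because JAX arrays are immutable, each function returns a new table built with `.at[...].add(...)` instead of modifying `q` in place.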

Installation

$ pip install git+https://github.com/hamishs/JAX-RL

About

License: MIT License


Languages

Python: 100.0%